From b677703cc0b4c99616c95e2b135701e905c27b47 Mon Sep 17 00:00:00 2001 From: tell-k Date: Sat, 24 Nov 2018 01:15:24 +0900 Subject: [PATCH 1/6] Add search_analyzer. --- elasticsearch_dsl/field.py | 1 + 1 file changed, 1 insertion(+) diff --git a/elasticsearch_dsl/field.py b/elasticsearch_dsl/field.py index 1b79b3293..fe86e07e9 100644 --- a/elasticsearch_dsl/field.py +++ b/elasticsearch_dsl/field.py @@ -365,6 +365,7 @@ class GeoShape(Field): class Completion(Field): _param_defs = { 'analyzer': {'type': 'analyzer'}, + 'search_analyzer': {'type': 'analyzer'}, } name = 'completion' From 9fac705250964a9f65d2286971ee7816dc0cc408 Mon Sep 17 00:00:00 2001 From: tell-k Date: Mon, 31 Dec 2018 01:42:57 +0900 Subject: [PATCH 2/6] Fix suggest query names for Completion suggestor. --- elasticsearch_dsl/search.py | 18 ++++++++-- test_elasticsearch_dsl/test_search.py | 48 +++++++++++++++++++++++++++ 2 files changed, 64 insertions(+), 2 deletions(-) diff --git a/elasticsearch_dsl/search.py b/elasticsearch_dsl/search.py index e5e0b8e26..8c404b9d5 100644 --- a/elasticsearch_dsl/search.py +++ b/elasticsearch_dsl/search.py @@ -605,7 +605,7 @@ def highlight(self, *fields, **kwargs): s._highlight[f] = kwargs return s - def suggest(self, name, text, **kwargs): + def suggest(self, name, text=None, regex=None, **kwargs): """ Add a suggestions request to the search. @@ -616,9 +616,23 @@ def suggest(self, name, text, **kwargs): s = Search() s = s.suggest('suggestion-1', 'Elasticsearch', term={'field': 'body'}) + """ + if text is None and regex is None: + raise ValueError('You have to pass "text" or "regex" argument.') + s = self._clone() - s._suggest[name] = {'text': text} + + query_name = 'text' + query_text = text + if 'completion' in kwargs: + if text: + query_name= 'prefix' + elif regex: + query_name= 'regex' + query_text = regex + + s._suggest[name] = {query_name: query_text} s._suggest[name].update(kwargs) return s diff --git a/test_elasticsearch_dsl/test_search.py b/test_elasticsearch_dsl/test_search.py index 53d3c95d6..6d9725ff4 100644 --- a/test_elasticsearch_dsl/test_search.py +++ b/test_elasticsearch_dsl/test_search.py @@ -505,6 +505,54 @@ def test_suggest(): } } == s.to_dict() + +def test_suggest_completion(): + s = search.Search() + s = s.suggest('my_suggestion', 'pyhton', completion={'field': 'title'}) + + assert { + 'suggest': { + 'my_suggestion': { + 'completion': {'field': 'title'}, + 'prefix': 'pyhton' + } + } + } == s.to_dict() + + +def test_suggest_regex_query(): + s = search.Search() + s = s.suggest('my_suggestion', regex='py[hton|py]', completion={'field': 'title'}) + + assert { + 'suggest': { + 'my_suggestion': { + 'completion': {'field': 'title'}, + 'regex': 'py[hton|py]' + } + } + } == s.to_dict() + + +def test_suggest_ignroe_regex_query(): + s = search.Search() + s = s.suggest('my_suggestion', text='python', regex='py[hton|py]', completion={'field': 'title'}) + + assert { + 'suggest': { + 'my_suggestion': { + 'completion': {'field': 'title'}, + 'prefix': 'python' + } + } + } == s.to_dict() + +def test_suggest_value_error(): + s = search.Search() + with raises(ValueError): + s.suggest('my_suggestion', completion={'field': 'title'}) + + def test_exclude(): s = search.Search() s = s.exclude('match', title='python') From 01fffdc1c5f86b4846e526448b98c6e735b337a1 Mon Sep 17 00:00:00 2001 From: tell-k Date: Mon, 14 Jan 2019 21:10:34 +0900 Subject: [PATCH 3/6] Fix that arguments validation is more strictly. --- elasticsearch_dsl/search.py | 28 +++++++++++++++------------ test_elasticsearch_dsl/test_search.py | 27 +++++++++++++------------- 2 files changed, 29 insertions(+), 26 deletions(-) diff --git a/elasticsearch_dsl/search.py b/elasticsearch_dsl/search.py index 8c404b9d5..5273d1b33 100644 --- a/elasticsearch_dsl/search.py +++ b/elasticsearch_dsl/search.py @@ -611,29 +611,33 @@ def suggest(self, name, text=None, regex=None, **kwargs): :arg name: name of the suggestion :arg text: text to suggest on + :arg regex: regex query for Completion Suggester All keyword arguments will be added to the suggestions body. For example:: s = Search() s = s.suggest('suggestion-1', 'Elasticsearch', term={'field': 'body'}) + # regex query for Completion Suggester + s = Search() + s = s.suggest('suggestion-1', regex='py[thon|py]', completion={'field': 'body'}) + """ if text is None and regex is None: raise ValueError('You have to pass "text" or "regex" argument.') + if text and regex: + raise ValueError('You can only pass either "text" or "regex" argument.') + if regex and 'completion' not in kwargs: + raise ValueError('"regex" argument must be passed with "completion" keyword argument.') s = self._clone() - - query_name = 'text' - query_text = text - if 'completion' in kwargs: - if text: - query_name= 'prefix' - elif regex: - query_name= 'regex' - query_text = regex - - s._suggest[name] = {query_name: query_text} - s._suggest[name].update(kwargs) + d = s._suggest[name] = kwargs + if regex: + d['regex'] = regex + elif 'completion' in kwargs: + d['prefix'] = text + else: + d['text'] = text return s def to_dict(self, count=False, **kwargs): diff --git a/test_elasticsearch_dsl/test_search.py b/test_elasticsearch_dsl/test_search.py index 6d9725ff4..18817dc42 100644 --- a/test_elasticsearch_dsl/test_search.py +++ b/test_elasticsearch_dsl/test_search.py @@ -522,35 +522,34 @@ def test_suggest_completion(): def test_suggest_regex_query(): s = search.Search() - s = s.suggest('my_suggestion', regex='py[hton|py]', completion={'field': 'title'}) + s = s.suggest('my_suggestion', regex='py[thon|py]', completion={'field': 'title'}) assert { 'suggest': { 'my_suggestion': { 'completion': {'field': 'title'}, - 'regex': 'py[hton|py]' + 'regex': 'py[thon|py]' } } } == s.to_dict() -def test_suggest_ignroe_regex_query(): +def test_suggest_must_pass_text_or_regex(): s = search.Search() - s = s.suggest('my_suggestion', text='python', regex='py[hton|py]', completion={'field': 'title'}) + with raises(ValueError): + s.suggest('my_suggestion') + + +def test_suggest_can_only_pass_text_or_regex(): + s = search.Search() + with raises(ValueError): + s.suggest('my_suggestion', text='python', regex='py[hton|py]') - assert { - 'suggest': { - 'my_suggestion': { - 'completion': {'field': 'title'}, - 'prefix': 'python' - } - } - } == s.to_dict() -def test_suggest_value_error(): +def test_suggest_regex_must_be_wtih_completion(): s = search.Search() with raises(ValueError): - s.suggest('my_suggestion', completion={'field': 'title'}) + s.suggest('my_suggestion', regex='py[thon|py]') def test_exclude(): From 90f43caf0c3d542687663060624b0d296d77b86e Mon Sep 17 00:00:00 2001 From: Caio Fontes Date: Thu, 6 Jun 2024 20:27:16 +0200 Subject: [PATCH 4/6] refactor: add type hints to wrappers.py --- elasticsearch_dsl/utils.py | 26 ++++++++--- elasticsearch_dsl/wrappers.py | 84 ++++++++++++++++++++++++++++++----- noxfile.py | 2 + tests/test_wrappers.py | 26 ++++++++--- 4 files changed, 116 insertions(+), 22 deletions(-) diff --git a/elasticsearch_dsl/utils.py b/elasticsearch_dsl/utils.py index 6e311316d..7fd0b08c7 100644 --- a/elasticsearch_dsl/utils.py +++ b/elasticsearch_dsl/utils.py @@ -18,13 +18,25 @@ import collections.abc from copy import copy -from typing import Any, ClassVar, Dict, List, Optional, Type, Union +from typing import Any, ClassVar, Dict, Generic, List, Optional, Type, TypeVar, Union -from typing_extensions import Self +from typing_extensions import Self, TypeAlias from .exceptions import UnknownDslObject, ValidationException -JSONType = Union[int, bool, str, float, List["JSONType"], Dict[str, "JSONType"]] +# Usefull types + +JSONType: TypeAlias = Union[ + int, bool, str, float, List["JSONType"], Dict[str, "JSONType"] +] + + +# Type variables for internals + +_KeyT = TypeVar("_KeyT") +_ValT = TypeVar("_ValT") + +# Constants SKIP_VALUES = ("", None) EXPAND__TO_DOT = True @@ -110,18 +122,20 @@ def to_list(self): return self._l_ -class AttrDict: +class AttrDict(Generic[_KeyT, _ValT]): """ Helper class to provide attribute like access (read and write) to dictionaries. Used to provide a convenient way to access both results and nested dsl dicts. """ - def __init__(self, d): + _d_: Dict[_KeyT, _ValT] + + def __init__(self, d: Dict[_KeyT, _ValT]): # assign the inner dict manually to prevent __setattr__ from firing super().__setattr__("_d_", d) - def __contains__(self, key): + def __contains__(self, key: object) -> bool: return key in self._d_ def __nonzero__(self): diff --git a/elasticsearch_dsl/wrappers.py b/elasticsearch_dsl/wrappers.py index 0dbca982f..9009d89d0 100644 --- a/elasticsearch_dsl/wrappers.py +++ b/elasticsearch_dsl/wrappers.py @@ -16,26 +16,78 @@ # under the License. import operator +from typing import ( + Any, + Callable, + ClassVar, + Dict, + Literal, + Mapping, + Optional, + Protocol, + Tuple, + TypeVar, + Union, + cast, +) + +from typing_extensions import TypeAlias from .utils import AttrDict -__all__ = ["Range"] + +class SupportsDunderLT(Protocol): + def __lt__(self, other: Any, /) -> Any: ... + + +class SupportsDunderGT(Protocol): + def __gt__(self, other: Any, /) -> Any: ... + + +class SupportsDunderLE(Protocol): + def __le__(self, other: Any, /) -> Any: ... + + +class SupportsDunderGE(Protocol): + def __ge__(self, other: Any, /) -> Any: ... -class Range(AttrDict): - OPS = { +SupportsComparison: TypeAlias = Union[ + SupportsDunderLE, SupportsDunderGE, SupportsDunderGT, SupportsDunderLT +] + +ComparisonOperators: TypeAlias = Literal["lt", "lte", "gt", "gte"] +RangeValT = TypeVar("RangeValT", bound=SupportsComparison) + +__all__ = ["Range", "SupportsComparison"] + + +class Range(AttrDict[ComparisonOperators, RangeValT]): + OPS: ClassVar[ + Mapping[ + ComparisonOperators, + Callable[[SupportsComparison, SupportsComparison], bool], + ] + ] = { "lt": operator.lt, "lte": operator.le, "gt": operator.gt, "gte": operator.ge, } - def __init__(self, *args, **kwargs): - if args and (len(args) > 1 or kwargs or not isinstance(args[0], dict)): + def __init__( + self, + d: Optional[Dict[ComparisonOperators, RangeValT]] = None, + /, + **kwargs: RangeValT, + ): + if d is not None and (kwargs or not isinstance(d, dict)): raise ValueError( "Range accepts a single dictionary or a set of keyword arguments." ) - data = args[0] if args else kwargs + + # Cast here since mypy is inferring d as an `object` type for some reason + data = cast(Dict[str, RangeValT], d) if d is not None else kwargs for k in data: if k not in self.OPS: @@ -47,22 +99,32 @@ def __init__(self, *args, **kwargs): if "lt" in data and "lte" in data: raise ValueError("You cannot specify both lt and lte for Range.") - super().__init__(args[0] if args else kwargs) + # Here we use cast() since we now the keys are in the allowed values, but mypy does + # not infer it. + super().__init__(cast(Dict[ComparisonOperators, RangeValT], data)) - def __repr__(self): + def __repr__(self) -> str: return "Range(%s)" % ", ".join("%s=%r" % op for op in self._d_.items()) - def __contains__(self, item): + def __contains__(self, item: object) -> bool: if isinstance(item, str): return super().__contains__(item) + item_supports_comp = any(hasattr(item, f"__{op}__") for op in self.OPS) + if not item_supports_comp: + return False + + # Cast to tell mypy whe have checked it and its ok to use the comparison methods + # on `item` + item = cast(SupportsComparison, item) + for op in self.OPS: if op in self._d_ and not self.OPS[op](item, self._d_[op]): return False return True @property - def upper(self): + def upper(self) -> Union[Tuple[RangeValT, bool], Tuple[None, Literal[False]]]: if "lt" in self._d_: return self._d_["lt"], False if "lte" in self._d_: @@ -70,7 +132,7 @@ def upper(self): return None, False @property - def lower(self): + def lower(self) -> Union[Tuple[RangeValT, bool], Tuple[None, Literal[False]]]: if "gt" in self._d_: return self._d_["gt"], False if "gte" in self._d_: diff --git a/noxfile.py b/noxfile.py index f90f22f04..a731f99b9 100644 --- a/noxfile.py +++ b/noxfile.py @@ -32,7 +32,9 @@ TYPED_FILES = ( "elasticsearch_dsl/function.py", "elasticsearch_dsl/query.py", + "elasticsearch_dsl/wrappers.py", "tests/test_query.py", + "tests/test_wrappers.py", ) diff --git a/tests/test_wrappers.py b/tests/test_wrappers.py index 454722711..4c8c93f41 100644 --- a/tests/test_wrappers.py +++ b/tests/test_wrappers.py @@ -16,10 +16,12 @@ # under the License. from datetime import datetime, timedelta +from typing import Any, Mapping, Optional, Sequence import pytest from elasticsearch_dsl import Range +from elasticsearch_dsl.wrappers import SupportsComparison @pytest.mark.parametrize( @@ -34,7 +36,9 @@ ({"gt": datetime.now() - timedelta(seconds=10)}, datetime.now()), ], ) -def test_range_contains(kwargs, item): +def test_range_contains( + kwargs: Mapping[str, SupportsComparison], item: SupportsComparison +) -> None: assert item in Range(**kwargs) @@ -48,7 +52,9 @@ def test_range_contains(kwargs, item): ({"lte": datetime.now() - timedelta(seconds=10)}, datetime.now()), ], ) -def test_range_not_contains(kwargs, item): +def test_range_not_contains( + kwargs: Mapping[str, SupportsComparison], item: SupportsComparison +) -> None: assert item not in Range(**kwargs) @@ -62,7 +68,9 @@ def test_range_not_contains(kwargs, item): ((), {"gt": 1, "gte": 1}), ], ) -def test_range_raises_value_error_on_wrong_params(args, kwargs): +def test_range_raises_value_error_on_wrong_params( + args: Sequence[Any], kwargs: Mapping[str, SupportsComparison] +) -> None: with pytest.raises(ValueError): Range(*args, **kwargs) @@ -76,7 +84,11 @@ def test_range_raises_value_error_on_wrong_params(args, kwargs): (Range(lt=42), None, False), ], ) -def test_range_lower(range, lower, inclusive): +def test_range_lower( + range: Range[SupportsComparison], + lower: Optional[SupportsComparison], + inclusive: bool, +) -> None: assert (lower, inclusive) == range.lower @@ -89,5 +101,9 @@ def test_range_lower(range, lower, inclusive): (Range(gt=42), None, False), ], ) -def test_range_upper(range, upper, inclusive): +def test_range_upper( + range: Range[SupportsComparison], + upper: Optional[SupportsComparison], + inclusive: bool, +) -> None: assert (upper, inclusive) == range.upper From 314829159210da05fa30970e33db65a26fbb2c3e Mon Sep 17 00:00:00 2001 From: Caio Fontes Date: Thu, 6 Jun 2024 21:20:06 +0200 Subject: [PATCH 5/6] Revert "refactor: add type hints to wrappers.py" This reverts commit 90f43caf0c3d542687663060624b0d296d77b86e. --- elasticsearch_dsl/utils.py | 26 +++-------- elasticsearch_dsl/wrappers.py | 84 +++++------------------------------ noxfile.py | 2 - tests/test_wrappers.py | 26 +++-------- 4 files changed, 22 insertions(+), 116 deletions(-) diff --git a/elasticsearch_dsl/utils.py b/elasticsearch_dsl/utils.py index 7fd0b08c7..6e311316d 100644 --- a/elasticsearch_dsl/utils.py +++ b/elasticsearch_dsl/utils.py @@ -18,25 +18,13 @@ import collections.abc from copy import copy -from typing import Any, ClassVar, Dict, Generic, List, Optional, Type, TypeVar, Union +from typing import Any, ClassVar, Dict, List, Optional, Type, Union -from typing_extensions import Self, TypeAlias +from typing_extensions import Self from .exceptions import UnknownDslObject, ValidationException -# Usefull types - -JSONType: TypeAlias = Union[ - int, bool, str, float, List["JSONType"], Dict[str, "JSONType"] -] - - -# Type variables for internals - -_KeyT = TypeVar("_KeyT") -_ValT = TypeVar("_ValT") - -# Constants +JSONType = Union[int, bool, str, float, List["JSONType"], Dict[str, "JSONType"]] SKIP_VALUES = ("", None) EXPAND__TO_DOT = True @@ -122,20 +110,18 @@ def to_list(self): return self._l_ -class AttrDict(Generic[_KeyT, _ValT]): +class AttrDict: """ Helper class to provide attribute like access (read and write) to dictionaries. Used to provide a convenient way to access both results and nested dsl dicts. """ - _d_: Dict[_KeyT, _ValT] - - def __init__(self, d: Dict[_KeyT, _ValT]): + def __init__(self, d): # assign the inner dict manually to prevent __setattr__ from firing super().__setattr__("_d_", d) - def __contains__(self, key: object) -> bool: + def __contains__(self, key): return key in self._d_ def __nonzero__(self): diff --git a/elasticsearch_dsl/wrappers.py b/elasticsearch_dsl/wrappers.py index 9009d89d0..0dbca982f 100644 --- a/elasticsearch_dsl/wrappers.py +++ b/elasticsearch_dsl/wrappers.py @@ -16,78 +16,26 @@ # under the License. import operator -from typing import ( - Any, - Callable, - ClassVar, - Dict, - Literal, - Mapping, - Optional, - Protocol, - Tuple, - TypeVar, - Union, - cast, -) - -from typing_extensions import TypeAlias from .utils import AttrDict - -class SupportsDunderLT(Protocol): - def __lt__(self, other: Any, /) -> Any: ... - - -class SupportsDunderGT(Protocol): - def __gt__(self, other: Any, /) -> Any: ... - - -class SupportsDunderLE(Protocol): - def __le__(self, other: Any, /) -> Any: ... - - -class SupportsDunderGE(Protocol): - def __ge__(self, other: Any, /) -> Any: ... +__all__ = ["Range"] -SupportsComparison: TypeAlias = Union[ - SupportsDunderLE, SupportsDunderGE, SupportsDunderGT, SupportsDunderLT -] - -ComparisonOperators: TypeAlias = Literal["lt", "lte", "gt", "gte"] -RangeValT = TypeVar("RangeValT", bound=SupportsComparison) - -__all__ = ["Range", "SupportsComparison"] - - -class Range(AttrDict[ComparisonOperators, RangeValT]): - OPS: ClassVar[ - Mapping[ - ComparisonOperators, - Callable[[SupportsComparison, SupportsComparison], bool], - ] - ] = { +class Range(AttrDict): + OPS = { "lt": operator.lt, "lte": operator.le, "gt": operator.gt, "gte": operator.ge, } - def __init__( - self, - d: Optional[Dict[ComparisonOperators, RangeValT]] = None, - /, - **kwargs: RangeValT, - ): - if d is not None and (kwargs or not isinstance(d, dict)): + def __init__(self, *args, **kwargs): + if args and (len(args) > 1 or kwargs or not isinstance(args[0], dict)): raise ValueError( "Range accepts a single dictionary or a set of keyword arguments." ) - - # Cast here since mypy is inferring d as an `object` type for some reason - data = cast(Dict[str, RangeValT], d) if d is not None else kwargs + data = args[0] if args else kwargs for k in data: if k not in self.OPS: @@ -99,32 +47,22 @@ def __init__( if "lt" in data and "lte" in data: raise ValueError("You cannot specify both lt and lte for Range.") - # Here we use cast() since we now the keys are in the allowed values, but mypy does - # not infer it. - super().__init__(cast(Dict[ComparisonOperators, RangeValT], data)) + super().__init__(args[0] if args else kwargs) - def __repr__(self) -> str: + def __repr__(self): return "Range(%s)" % ", ".join("%s=%r" % op for op in self._d_.items()) - def __contains__(self, item: object) -> bool: + def __contains__(self, item): if isinstance(item, str): return super().__contains__(item) - item_supports_comp = any(hasattr(item, f"__{op}__") for op in self.OPS) - if not item_supports_comp: - return False - - # Cast to tell mypy whe have checked it and its ok to use the comparison methods - # on `item` - item = cast(SupportsComparison, item) - for op in self.OPS: if op in self._d_ and not self.OPS[op](item, self._d_[op]): return False return True @property - def upper(self) -> Union[Tuple[RangeValT, bool], Tuple[None, Literal[False]]]: + def upper(self): if "lt" in self._d_: return self._d_["lt"], False if "lte" in self._d_: @@ -132,7 +70,7 @@ def upper(self) -> Union[Tuple[RangeValT, bool], Tuple[None, Literal[False]]]: return None, False @property - def lower(self) -> Union[Tuple[RangeValT, bool], Tuple[None, Literal[False]]]: + def lower(self): if "gt" in self._d_: return self._d_["gt"], False if "gte" in self._d_: diff --git a/noxfile.py b/noxfile.py index a731f99b9..f90f22f04 100644 --- a/noxfile.py +++ b/noxfile.py @@ -32,9 +32,7 @@ TYPED_FILES = ( "elasticsearch_dsl/function.py", "elasticsearch_dsl/query.py", - "elasticsearch_dsl/wrappers.py", "tests/test_query.py", - "tests/test_wrappers.py", ) diff --git a/tests/test_wrappers.py b/tests/test_wrappers.py index 4c8c93f41..454722711 100644 --- a/tests/test_wrappers.py +++ b/tests/test_wrappers.py @@ -16,12 +16,10 @@ # under the License. from datetime import datetime, timedelta -from typing import Any, Mapping, Optional, Sequence import pytest from elasticsearch_dsl import Range -from elasticsearch_dsl.wrappers import SupportsComparison @pytest.mark.parametrize( @@ -36,9 +34,7 @@ ({"gt": datetime.now() - timedelta(seconds=10)}, datetime.now()), ], ) -def test_range_contains( - kwargs: Mapping[str, SupportsComparison], item: SupportsComparison -) -> None: +def test_range_contains(kwargs, item): assert item in Range(**kwargs) @@ -52,9 +48,7 @@ def test_range_contains( ({"lte": datetime.now() - timedelta(seconds=10)}, datetime.now()), ], ) -def test_range_not_contains( - kwargs: Mapping[str, SupportsComparison], item: SupportsComparison -) -> None: +def test_range_not_contains(kwargs, item): assert item not in Range(**kwargs) @@ -68,9 +62,7 @@ def test_range_not_contains( ((), {"gt": 1, "gte": 1}), ], ) -def test_range_raises_value_error_on_wrong_params( - args: Sequence[Any], kwargs: Mapping[str, SupportsComparison] -) -> None: +def test_range_raises_value_error_on_wrong_params(args, kwargs): with pytest.raises(ValueError): Range(*args, **kwargs) @@ -84,11 +76,7 @@ def test_range_raises_value_error_on_wrong_params( (Range(lt=42), None, False), ], ) -def test_range_lower( - range: Range[SupportsComparison], - lower: Optional[SupportsComparison], - inclusive: bool, -) -> None: +def test_range_lower(range, lower, inclusive): assert (lower, inclusive) == range.lower @@ -101,9 +89,5 @@ def test_range_lower( (Range(gt=42), None, False), ], ) -def test_range_upper( - range: Range[SupportsComparison], - upper: Optional[SupportsComparison], - inclusive: bool, -) -> None: +def test_range_upper(range, upper, inclusive): assert (upper, inclusive) == range.upper From 720467ed31051b3acadcabc36096c65623606492 Mon Sep 17 00:00:00 2001 From: Caio Fontes Date: Thu, 6 Jun 2024 21:37:55 +0200 Subject: [PATCH 6/6] feat: readd regex implementation --- elasticsearch_dsl/search_base.py | 22 ++++++++++++++++-- tests/_async/test_search.py | 40 ++++++++++++++++++++++++++++++++ tests/_sync/test_search.py | 40 ++++++++++++++++++++++++++++++++ 3 files changed, 100 insertions(+), 2 deletions(-) diff --git a/elasticsearch_dsl/search_base.py b/elasticsearch_dsl/search_base.py index 5680778cb..7a940d4ad 100644 --- a/elasticsearch_dsl/search_base.py +++ b/elasticsearch_dsl/search_base.py @@ -743,7 +743,7 @@ def highlight(self, *fields, **kwargs): s._highlight[f] = kwargs return s - def suggest(self, name, text, **kwargs): + def suggest(self, name, text=None, regex=None, **kwargs): """ Add a suggestions request to the search. @@ -754,9 +754,27 @@ def suggest(self, name, text, **kwargs): s = Search() s = s.suggest('suggestion-1', 'Elasticsearch', term={'field': 'body'}) + + # regex query for Completion Suggester + s = Search() + s = s.suggest('suggestion-1', regex='py[thon|py]', completion={'field': 'body'}) """ + if text is None and regex is None: + raise ValueError('You have to pass "text" or "regex" argument.') + if text and regex: + raise ValueError('You can only pass either "text" or "regex" argument.') + if regex and "completion" not in kwargs: + raise ValueError( + '"regex" argument must be passed with "completion" keyword argument.' + ) + s = self._clone() - s._suggest[name] = {"text": text} + if regex: + s._suggest[name] = {"regex": regex} + elif "completion" in kwargs: + s._suggest[name] = {"prefix": text} + else: + s._suggest[name] = {"text": text} s._suggest[name].update(kwargs) return s diff --git a/tests/_async/test_search.py b/tests/_async/test_search.py index 1fc1ac987..6aedb0ed9 100644 --- a/tests/_async/test_search.py +++ b/tests/_async/test_search.py @@ -718,3 +718,43 @@ async def test_empty_search(): assert [hit async for hit in s] == [] assert [hit async for hit in s.scan()] == [] await s.delete() # should not error + + +def test_suggest_completion(): + s = AsyncSearch() + s = s.suggest("my_suggestion", "pyhton", completion={"field": "title"}) + + assert { + "suggest": { + "my_suggestion": {"completion": {"field": "title"}, "prefix": "pyhton"} + } + } == s.to_dict() + + +def test_suggest_regex_query(): + s = AsyncSearch() + s = s.suggest("my_suggestion", regex="py[thon|py]", completion={"field": "title"}) + + assert { + "suggest": { + "my_suggestion": {"completion": {"field": "title"}, "regex": "py[thon|py]"} + } + } == s.to_dict() + + +def test_suggest_must_pass_text_or_regex(): + s = AsyncSearch() + with raises(ValueError): + s.suggest("my_suggestion") + + +def test_suggest_can_only_pass_text_or_regex(): + s = AsyncSearch() + with raises(ValueError): + s.suggest("my_suggestion", text="python", regex="py[hton|py]") + + +def test_suggest_regex_must_be_wtih_completion(): + s = AsyncSearch() + with raises(ValueError): + s.suggest("my_suggestion", regex="py[thon|py]") diff --git a/tests/_sync/test_search.py b/tests/_sync/test_search.py index 9a32ee13c..dcf0fe6f3 100644 --- a/tests/_sync/test_search.py +++ b/tests/_sync/test_search.py @@ -716,3 +716,43 @@ def test_empty_search(): assert [hit for hit in s] == [] assert [hit for hit in s.scan()] == [] s.delete() # should not error + + +def test_suggest_completion(): + s = Search() + s = s.suggest("my_suggestion", "pyhton", completion={"field": "title"}) + + assert { + "suggest": { + "my_suggestion": {"completion": {"field": "title"}, "prefix": "pyhton"} + } + } == s.to_dict() + + +def test_suggest_regex_query(): + s = Search() + s = s.suggest("my_suggestion", regex="py[thon|py]", completion={"field": "title"}) + + assert { + "suggest": { + "my_suggestion": {"completion": {"field": "title"}, "regex": "py[thon|py]"} + } + } == s.to_dict() + + +def test_suggest_must_pass_text_or_regex(): + s = Search() + with raises(ValueError): + s.suggest("my_suggestion") + + +def test_suggest_can_only_pass_text_or_regex(): + s = Search() + with raises(ValueError): + s.suggest("my_suggestion", text="python", regex="py[hton|py]") + + +def test_suggest_regex_must_be_wtih_completion(): + s = Search() + with raises(ValueError): + s.suggest("my_suggestion", regex="py[thon|py]")