diff --git a/.github/workflows/integration.yaml b/.github/workflows/integration.yaml index dcc21b5a9c..7d0402e9d5 100644 --- a/.github/workflows/integration.yaml +++ b/.github/workflows/integration.yaml @@ -125,7 +125,7 @@ jobs: redis-version: [ '${{ needs.redis_version.outputs.CURRENT }}' ] python-version: [ '3.8', '3.13'] parser-backend: [ 'hiredis' ] - hiredis-version: [ '>=3.0.0', '<3.0.0' ] + hiredis-version: [ '>=3.2.0', '<3.0.0' ] event-loop: [ 'asyncio' ] env: ACTIONS_ALLOW_UNSECURE_COMMANDS: true diff --git a/pyproject.toml b/pyproject.toml index 5cd40c0212..ac692ef9d1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,7 +33,7 @@ dependencies = ['async-timeout>=4.0.3; python_full_version<"3.11.3"'] [project.optional-dependencies] hiredis = [ - "hiredis>=3.0.0", + "hiredis>=3.2.0", ] ocsp = [ "cryptography>=36.0.1", diff --git a/redis/_parsers/__init__.py b/redis/_parsers/__init__.py index 6cc32e3cae..30cb1cd5b9 100644 --- a/redis/_parsers/__init__.py +++ b/redis/_parsers/__init__.py @@ -1,4 +1,9 @@ -from .base import BaseParser, _AsyncRESPBase +from .base import ( + AsyncPushNotificationsParser, + BaseParser, + PushNotificationsParser, + _AsyncRESPBase, +) from .commands import AsyncCommandsParser, CommandsParser from .encoders import Encoder from .hiredis import _AsyncHiredisParser, _HiredisParser @@ -11,10 +16,12 @@ "_AsyncRESPBase", "_AsyncRESP2Parser", "_AsyncRESP3Parser", + "AsyncPushNotificationsParser", "CommandsParser", "Encoder", "BaseParser", "_HiredisParser", "_RESP2Parser", "_RESP3Parser", + "PushNotificationsParser", ] diff --git a/redis/_parsers/base.py b/redis/_parsers/base.py index ebc8313ce7..69d7b585dd 100644 --- a/redis/_parsers/base.py +++ b/redis/_parsers/base.py @@ -1,7 +1,7 @@ import sys from abc import ABC from asyncio import IncompleteReadError, StreamReader, TimeoutError -from typing import List, Optional, Union +from typing import Callable, List, Optional, Protocol, Union if sys.version_info.major >= 3 and sys.version_info.minor >= 11: from asyncio import timeout as async_timeout @@ -158,6 +158,58 @@ async def read_response( raise NotImplementedError() +_INVALIDATION_MESSAGE = [b"invalidate", "invalidate"] + + +class PushNotificationsParser(Protocol): + """Protocol defining RESP3-specific parsing functionality""" + + pubsub_push_handler_func: Callable + invalidation_push_handler_func: Optional[Callable] = None + + def handle_pubsub_push_response(self, response): + """Handle pubsub push responses""" + raise NotImplementedError() + + def handle_push_response(self, response, **kwargs): + if response[0] not in _INVALIDATION_MESSAGE: + return self.pubsub_push_handler_func(response) + if self.invalidation_push_handler_func: + return self.invalidation_push_handler_func(response) + + def set_pubsub_push_handler(self, pubsub_push_handler_func): + self.pubsub_push_handler_func = pubsub_push_handler_func + + def set_invalidation_push_handler(self, invalidation_push_handler_func): + self.invalidation_push_handler_func = invalidation_push_handler_func + + +class AsyncPushNotificationsParser(Protocol): + """Protocol defining async RESP3-specific parsing functionality""" + + pubsub_push_handler_func: Callable + invalidation_push_handler_func: Optional[Callable] = None + + async def handle_pubsub_push_response(self, response): + """Handle pubsub push responses asynchronously""" + raise NotImplementedError() + + async def handle_push_response(self, response, **kwargs): + """Handle push responses asynchronously""" + if response[0] not in _INVALIDATION_MESSAGE: + return await self.pubsub_push_handler_func(response) + if self.invalidation_push_handler_func: + return await self.invalidation_push_handler_func(response) + + def set_pubsub_push_handler(self, pubsub_push_handler_func): + """Set the pubsub push handler function""" + self.pubsub_push_handler_func = pubsub_push_handler_func + + def set_invalidation_push_handler(self, invalidation_push_handler_func): + """Set the invalidation push handler function""" + self.invalidation_push_handler_func = invalidation_push_handler_func + + class _AsyncRESPBase(AsyncBaseParser): """Base class for async resp parsing""" diff --git a/redis/_parsers/hiredis.py b/redis/_parsers/hiredis.py index c807bd903a..521a58b26c 100644 --- a/redis/_parsers/hiredis.py +++ b/redis/_parsers/hiredis.py @@ -1,6 +1,7 @@ import asyncio import socket import sys +from logging import getLogger from typing import Callable, List, Optional, TypedDict, Union if sys.version_info.major >= 3 and sys.version_info.minor >= 11: @@ -11,7 +12,12 @@ from ..exceptions import ConnectionError, InvalidResponse, RedisError from ..typing import EncodableT from ..utils import HIREDIS_AVAILABLE -from .base import AsyncBaseParser, BaseParser +from .base import ( + AsyncBaseParser, + AsyncPushNotificationsParser, + BaseParser, + PushNotificationsParser, +) from .socket import ( NONBLOCKING_EXCEPTION_ERROR_NUMBERS, NONBLOCKING_EXCEPTIONS, @@ -32,7 +38,7 @@ class _HiredisReaderArgs(TypedDict, total=False): errors: Optional[str] -class _HiredisParser(BaseParser): +class _HiredisParser(BaseParser, PushNotificationsParser): "Parser class for connections using Hiredis" def __init__(self, socket_read_size): @@ -40,6 +46,9 @@ def __init__(self, socket_read_size): raise RedisError("Hiredis is not installed") self.socket_read_size = socket_read_size self._buffer = bytearray(socket_read_size) + self.pubsub_push_handler_func = self.handle_pubsub_push_response + self.invalidation_push_handler_func = None + self._hiredis_PushNotificationType = None def __del__(self): try: @@ -47,6 +56,11 @@ def __del__(self): except Exception: pass + def handle_pubsub_push_response(self, response): + logger = getLogger("push_response") + logger.debug("Push response: " + str(response)) + return response + def on_connect(self, connection, **kwargs): import hiredis @@ -64,6 +78,12 @@ def on_connect(self, connection, **kwargs): self._reader = hiredis.Reader(**kwargs) self._next_response = NOT_ENOUGH_DATA + try: + self._hiredis_PushNotificationType = hiredis.PushNotification + except AttributeError: + # hiredis < 3.2 + self._hiredis_PushNotificationType = None + def on_disconnect(self): self._sock = None self._reader = None @@ -109,7 +129,7 @@ def read_from_socket(self, timeout=SENTINEL, raise_on_timeout=True): if custom_timeout: sock.settimeout(self._socket_timeout) - def read_response(self, disable_decoding=False): + def read_response(self, disable_decoding=False, push_request=False): if not self._reader: raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) @@ -117,6 +137,16 @@ def read_response(self, disable_decoding=False): if self._next_response is not NOT_ENOUGH_DATA: response = self._next_response self._next_response = NOT_ENOUGH_DATA + if self._hiredis_PushNotificationType is not None and isinstance( + response, self._hiredis_PushNotificationType + ): + response = self.handle_push_response(response) + if not push_request: + return self.read_response( + disable_decoding=disable_decoding, push_request=push_request + ) + else: + return response return response if disable_decoding: @@ -135,6 +165,16 @@ def read_response(self, disable_decoding=False): # happened if isinstance(response, ConnectionError): raise response + elif self._hiredis_PushNotificationType is not None and isinstance( + response, self._hiredis_PushNotificationType + ): + response = self.handle_push_response(response) + if not push_request: + return self.read_response( + disable_decoding=disable_decoding, push_request=push_request + ) + else: + return response elif ( isinstance(response, list) and response @@ -144,7 +184,7 @@ def read_response(self, disable_decoding=False): return response -class _AsyncHiredisParser(AsyncBaseParser): +class _AsyncHiredisParser(AsyncBaseParser, AsyncPushNotificationsParser): """Async implementation of parser class for connections using Hiredis""" __slots__ = ("_reader",) @@ -154,6 +194,14 @@ def __init__(self, socket_read_size: int): raise RedisError("Hiredis is not available.") super().__init__(socket_read_size=socket_read_size) self._reader = None + self.pubsub_push_handler_func = self.handle_pubsub_push_response + self.invalidation_push_handler_func = None + self._hiredis_PushNotificationType = None + + async def handle_pubsub_push_response(self, response): + logger = getLogger("push_response") + logger.debug("Push response: " + str(response)) + return response def on_connect(self, connection): import hiredis @@ -171,6 +219,14 @@ def on_connect(self, connection): self._reader = hiredis.Reader(**kwargs) self._connected = True + try: + self._hiredis_PushNotificationType = getattr( + hiredis, "PushNotification", None + ) + except AttributeError: + # hiredis < 3.2 + self._hiredis_PushNotificationType = None + def on_disconnect(self): self._connected = False @@ -195,7 +251,7 @@ async def read_from_socket(self): return True async def read_response( - self, disable_decoding: bool = False + self, disable_decoding: bool = False, push_request: bool = False ) -> Union[EncodableT, List[EncodableT]]: # If `on_disconnect()` has been called, prohibit any more reads # even if they could happen because data might be present. @@ -207,6 +263,7 @@ async def read_response( response = self._reader.gets(False) else: response = self._reader.gets() + while response is NOT_ENOUGH_DATA: await self.read_from_socket() if disable_decoding: @@ -219,6 +276,16 @@ async def read_response( # happened if isinstance(response, ConnectionError): raise response + elif self._hiredis_PushNotificationType is not None and isinstance( + response, self._hiredis_PushNotificationType + ): + response = await self.handle_push_response(response) + if not push_request: + return await self.read_response( + disable_decoding=disable_decoding, push_request=push_request + ) + else: + return response elif ( isinstance(response, list) and response diff --git a/redis/_parsers/resp3.py b/redis/_parsers/resp3.py index ce4c59fb5b..42c6652e31 100644 --- a/redis/_parsers/resp3.py +++ b/redis/_parsers/resp3.py @@ -3,13 +3,16 @@ from ..exceptions import ConnectionError, InvalidResponse, ResponseError from ..typing import EncodableT -from .base import _AsyncRESPBase, _RESPBase +from .base import ( + AsyncPushNotificationsParser, + PushNotificationsParser, + _AsyncRESPBase, + _RESPBase, +) from .socket import SERVER_CLOSED_CONNECTION_ERROR -_INVALIDATION_MESSAGE = [b"invalidate", "invalidate"] - -class _RESP3Parser(_RESPBase): +class _RESP3Parser(_RESPBase, PushNotificationsParser): """RESP3 protocol implementation""" def __init__(self, socket_read_size): @@ -113,9 +116,7 @@ def _read_response(self, disable_decoding=False, push_request=False): ) for _ in range(int(response)) ] - response = self.handle_push_response( - response, disable_decoding, push_request - ) + response = self.handle_push_response(response) if not push_request: return self._read_response( disable_decoding=disable_decoding, push_request=push_request @@ -129,20 +130,8 @@ def _read_response(self, disable_decoding=False, push_request=False): response = self.encoder.decode(response) return response - def handle_push_response(self, response, disable_decoding, push_request): - if response[0] not in _INVALIDATION_MESSAGE: - return self.pubsub_push_handler_func(response) - if self.invalidation_push_handler_func: - return self.invalidation_push_handler_func(response) - - def set_pubsub_push_handler(self, pubsub_push_handler_func): - self.pubsub_push_handler_func = pubsub_push_handler_func - - def set_invalidation_push_handler(self, invalidation_push_handler_func): - self.invalidation_push_handler_func = invalidation_push_handler_func - -class _AsyncRESP3Parser(_AsyncRESPBase): +class _AsyncRESP3Parser(_AsyncRESPBase, AsyncPushNotificationsParser): def __init__(self, socket_read_size): super().__init__(socket_read_size) self.pubsub_push_handler_func = self.handle_pubsub_push_response @@ -253,9 +242,7 @@ async def _read_response( ) for _ in range(int(response)) ] - response = await self.handle_push_response( - response, disable_decoding, push_request - ) + response = await self.handle_push_response(response) if not push_request: return await self._read_response( disable_decoding=disable_decoding, push_request=push_request @@ -268,15 +255,3 @@ async def _read_response( if isinstance(response, bytes) and disable_decoding is False: response = self.encoder.decode(response) return response - - async def handle_push_response(self, response, disable_decoding, push_request): - if response[0] not in _INVALIDATION_MESSAGE: - return await self.pubsub_push_handler_func(response) - if self.invalidation_push_handler_func: - return await self.invalidation_push_handler_func(response) - - def set_pubsub_push_handler(self, pubsub_push_handler_func): - self.pubsub_push_handler_func = pubsub_push_handler_func - - def set_invalidation_push_handler(self, invalidation_push_handler_func): - self.invalidation_push_handler_func = invalidation_push_handler_func diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index 1cb28e725e..aac409073f 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -70,7 +70,6 @@ ) from redis.typing import ChannelT, EncodableT, KeyT from redis.utils import ( - HIREDIS_AVAILABLE, SSL_AVAILABLE, _set_info_logger, deprecated_args, @@ -938,7 +937,7 @@ async def connect(self): self.connection.register_connect_callback(self.on_connect) else: await self.connection.connect() - if self.push_handler_func is not None and not HIREDIS_AVAILABLE: + if self.push_handler_func is not None: self.connection._parser.set_pubsub_push_handler(self.push_handler_func) self._event_dispatcher.dispatch( diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index d1ae81d269..d6c03f17c5 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -576,11 +576,7 @@ async def read_response( read_timeout = timeout if timeout is not None else self.socket_timeout host_error = self._host_error() try: - if ( - read_timeout is not None - and self.protocol in ["3", 3] - and not HIREDIS_AVAILABLE - ): + if read_timeout is not None and self.protocol in ["3", 3]: async with async_timeout(read_timeout): response = await self._parser.read_response( disable_decoding=disable_decoding, push_request=push_request @@ -590,7 +586,7 @@ async def read_response( response = await self._parser.read_response( disable_decoding=disable_decoding ) - elif self.protocol in ["3", 3] and not HIREDIS_AVAILABLE: + elif self.protocol in ["3", 3]: response = await self._parser.read_response( disable_decoding=disable_decoding, push_request=push_request ) diff --git a/redis/client.py b/redis/client.py index dc4f0f9d0c..c662e591e0 100755 --- a/redis/client.py +++ b/redis/client.py @@ -58,7 +58,6 @@ from redis.lock import Lock from redis.retry import Retry from redis.utils import ( - HIREDIS_AVAILABLE, _set_info_logger, deprecated_args, get_lib_version, @@ -861,7 +860,7 @@ def execute_command(self, *args): # register a callback that re-subscribes to any channels we # were listening to when we were disconnected self.connection.register_connect_callback(self.on_connect) - if self.push_handler_func is not None and not HIREDIS_AVAILABLE: + if self.push_handler_func is not None: self.connection._parser.set_pubsub_push_handler(self.push_handler_func) self._event_dispatcher.dispatch( AfterPubSubConnectionInstantiationEvent( diff --git a/redis/cluster.py b/redis/cluster.py index af60e1c76c..bbc4204bdf 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -51,7 +51,6 @@ from redis.lock import Lock from redis.retry import Retry from redis.utils import ( - HIREDIS_AVAILABLE, deprecated_args, dict_merge, list_keys_to_dict, @@ -1999,7 +1998,7 @@ def execute_command(self, *args): # register a callback that re-subscribes to any channels we # were listening to when we were disconnected self.connection.register_connect_callback(self.on_connect) - if self.push_handler_func is not None and not HIREDIS_AVAILABLE: + if self.push_handler_func is not None: self.connection._parser.set_pubsub_push_handler(self.push_handler_func) self._event_dispatcher.dispatch( AfterPubSubConnectionInstantiationEvent( diff --git a/redis/connection.py b/redis/connection.py index cc805e442f..0884bca362 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -636,7 +636,7 @@ def read_response( host_error = self._host_error() try: - if self.protocol in ["3", 3] and not HIREDIS_AVAILABLE: + if self.protocol in ["3", 3]: response = self._parser.read_response( disable_decoding=disable_decoding, push_request=push_request ) diff --git a/redis/utils.py b/redis/utils.py index 1f0b24d768..715913e914 100644 --- a/redis/utils.py +++ b/redis/utils.py @@ -12,9 +12,12 @@ import hiredis # noqa # Only support Hiredis >= 3.0: - HIREDIS_AVAILABLE = int(hiredis.__version__.split(".")[0]) >= 3 + hiredis_version = hiredis.__version__.split(".") + HIREDIS_AVAILABLE = int(hiredis_version[0]) > 3 or ( + int(hiredis_version[0]) == 3 and int(hiredis_version[1]) >= 2 + ) if not HIREDIS_AVAILABLE: - raise ImportError("hiredis package should be >= 3.0.0") + raise ImportError("hiredis package should be >= 3.2.0") except ImportError: HIREDIS_AVAILABLE = False diff --git a/tests/test_asyncio/test_pubsub.py b/tests/test_asyncio/test_pubsub.py index 13a6158b40..5ea89a0b8b 100644 --- a/tests/test_asyncio/test_pubsub.py +++ b/tests/test_asyncio/test_pubsub.py @@ -17,7 +17,6 @@ import redis.asyncio as redis from redis.exceptions import ConnectionError from redis.typing import EncodableT -from redis.utils import HIREDIS_AVAILABLE from tests.conftest import get_protocol_version, skip_if_server_version_lt from .compat import aclosing, create_task, mock @@ -464,7 +463,6 @@ class TestPubSubRESP3Handler: async def my_handler(self, message): self.message = ["my handler", message] - @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") async def test_push_handler(self, r): if get_protocol_version(r) in [2, "2", None]: return diff --git a/tests/test_asyncio/test_search.py b/tests/test_asyncio/test_search.py index 2ee74f710f..932ece59b8 100644 --- a/tests/test_asyncio/test_search.py +++ b/tests/test_asyncio/test_search.py @@ -7,6 +7,7 @@ import numpy as np import pytest import pytest_asyncio +from redis import ResponseError import redis.asyncio as redis import redis.commands.search.aggregation as aggregations import redis.commands.search.reducers as reducers @@ -60,6 +61,10 @@ async def waitForIndex(env, idx, timeout=None): break except ValueError: break + except ResponseError: + # index doesn't exist yet + # continue to sleep and try again + pass await asyncio.sleep(delay) if timeout is not None: diff --git a/tests/test_cache.py b/tests/test_cache.py index a305d2de7b..1f3193c49d 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -13,7 +13,6 @@ EvictionPolicyType, LRUPolicy, ) -from redis.utils import HIREDIS_AVAILABLE from tests.conftest import _get_client, skip_if_resp_version, skip_if_server_version_lt @@ -40,7 +39,6 @@ def r(request): yield client -@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") @pytest.mark.onlynoncluster @skip_if_resp_version(2) @skip_if_server_version_lt("7.4.0") @@ -124,6 +122,10 @@ def test_get_from_default_cache(self, r, r2): ] # change key in redis (cause invalidation) r2.set("foo", "barbar") + + # Add a small delay to allow invalidation to be processed + time.sleep(0.1) + # Retrieves a new value from server and cache it assert r.get("foo") in [b"barbar", "barbar"] # Make sure that new value was cached @@ -325,7 +327,6 @@ def test_cache_flushed_on_server_flush(self, r): assert cache.size == 0 -@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") @pytest.mark.onlycluster @skip_if_resp_version(2) @skip_if_server_version_lt("7.4.0") @@ -568,7 +569,6 @@ def test_cache_flushed_on_server_flush(self, r, cache): assert cache.size == 0 -@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") @pytest.mark.onlynoncluster @skip_if_resp_version(2) @skip_if_server_version_lt("7.4.0") @@ -674,7 +674,6 @@ def test_cache_clears_on_disconnect(self, master, cache): assert cache.size == 0 -@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") @pytest.mark.onlynoncluster @skip_if_resp_version(2) @skip_if_server_version_lt("7.4.0") diff --git a/tests/test_connection_pool.py b/tests/test_connection_pool.py index d97c9063ac..9e67659fa9 100644 --- a/tests/test_connection_pool.py +++ b/tests/test_connection_pool.py @@ -9,7 +9,7 @@ import redis from redis.cache import CacheConfig from redis.connection import CacheProxyConnection, Connection, to_bool -from redis.utils import HIREDIS_AVAILABLE, SSL_AVAILABLE +from redis.utils import SSL_AVAILABLE from .conftest import ( _get_client, @@ -217,7 +217,6 @@ def test_repr_contains_db_info_unix(self): expected = "path=abc,db=0,client_name=test-client" assert expected in repr(pool) - @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") @pytest.mark.onlynoncluster @skip_if_resp_version(2) @skip_if_server_version_lt("7.4.0") diff --git a/tests/test_pubsub.py b/tests/test_pubsub.py index 9ead455af3..ac6965a188 100644 --- a/tests/test_pubsub.py +++ b/tests/test_pubsub.py @@ -10,7 +10,6 @@ import pytest import redis from redis.exceptions import ConnectionError -from redis.utils import HIREDIS_AVAILABLE from .conftest import ( _get_client, @@ -593,7 +592,6 @@ class TestPubSubRESP3Handler: def my_handler(self, message): self.message = ["my handler", message] - @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") def test_push_handler(self, r): if is_resp2_connection(r): return @@ -605,7 +603,6 @@ def test_push_handler(self, r): assert wait_for_message(p) is None assert self.message == ["my handler", [b"message", b"foo", b"test message"]] - @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") @skip_if_server_version_lt("7.0.0") def test_push_handler_sharded_pubsub(self, r): if is_resp2_connection(r): diff --git a/tests/test_search.py b/tests/test_search.py index 7e4f59eb79..4af55e8a17 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -6,6 +6,7 @@ import numpy as np import pytest +from redis import ResponseError import redis import redis.commands.search import redis.commands.search.aggregation as aggregations @@ -47,8 +48,8 @@ def waitForIndex(env, idx, timeout=None): delay = 0.1 while True: - res = env.execute_command("FT.INFO", idx) try: + res = env.execute_command("FT.INFO", idx) if int(res[res.index("indexing") + 1]) == 0: break except ValueError: @@ -59,6 +60,10 @@ def waitForIndex(env, idx, timeout=None): break except ValueError: break + except ResponseError: + # index doesn't exist yet + # continue to sleep and try again + pass time.sleep(delay) if timeout is not None: @@ -1909,6 +1914,8 @@ def test_binary_and_text_fields(client): ), ) + waitForIndex(client, index_name) + query = ( Query("*") .return_field("vector_emb", decode_field=False)