diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index 95390bd66c..db8025b6f2 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -295,13 +295,18 @@ async def connect(self): """Connects to the Redis server if not already connected""" await self.connect_check_health(check_health=True) - async def connect_check_health(self, check_health: bool = True): + async def connect_check_health( + self, check_health: bool = True, retry_socket_connect: bool = True + ): if self.is_connected: return try: - await self.retry.call_with_retry( - lambda: self._connect(), lambda error: self.disconnect() - ) + if retry_socket_connect: + await self.retry.call_with_retry( + lambda: self._connect(), lambda error: self.disconnect() + ) + else: + await self._connect() except asyncio.CancelledError: raise # in 3.7 and earlier, this is an Exception, not BaseException except (socket.timeout, asyncio.TimeoutError): diff --git a/redis/asyncio/sentinel.py b/redis/asyncio/sentinel.py index 0bf7086555..d0455ab6eb 100644 --- a/redis/asyncio/sentinel.py +++ b/redis/asyncio/sentinel.py @@ -11,8 +11,12 @@ SSLConnection, ) from redis.commands import AsyncSentinelCommands -from redis.exceptions import ConnectionError, ReadOnlyError, ResponseError, TimeoutError -from redis.utils import str_if_bytes +from redis.exceptions import ( + ConnectionError, + ReadOnlyError, + ResponseError, + TimeoutError, +) class MasterNotFoundError(ConnectionError): @@ -37,11 +41,10 @@ def __repr__(self): async def connect_to(self, address): self.host, self.port = address - await super().connect() - if self.connection_pool.check_connection: - await self.send_command("PING") - if str_if_bytes(await self.read_response()) != "PONG": - raise ConnectionError("PING failed") + await self.connect_check_health( + check_health=self.connection_pool.check_connection, + retry_socket_connect=False, + ) async def _connect_retry(self): if self._reader: diff --git a/redis/connection.py b/redis/connection.py index d457b1015c..a456514a88 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -378,13 +378,18 @@ def connect(self): "Connects to the Redis server if not already connected" self.connect_check_health(check_health=True) - def connect_check_health(self, check_health: bool = True): + def connect_check_health( + self, check_health: bool = True, retry_socket_connect: bool = True + ): if self._sock: return try: - sock = self.retry.call_with_retry( - lambda: self._connect(), lambda error: self.disconnect(error) - ) + if retry_socket_connect: + sock = self.retry.call_with_retry( + lambda: self._connect(), lambda error: self.disconnect(error) + ) + else: + sock = self._connect() except socket.timeout: raise TimeoutError("Timeout connecting to server") except OSError as e: diff --git a/redis/sentinel.py b/redis/sentinel.py index 198639c932..f12bd8dd5d 100644 --- a/redis/sentinel.py +++ b/redis/sentinel.py @@ -5,8 +5,12 @@ from redis.client import Redis from redis.commands import SentinelCommands from redis.connection import Connection, ConnectionPool, SSLConnection -from redis.exceptions import ConnectionError, ReadOnlyError, ResponseError, TimeoutError -from redis.utils import str_if_bytes +from redis.exceptions import ( + ConnectionError, + ReadOnlyError, + ResponseError, + TimeoutError, +) class MasterNotFoundError(ConnectionError): @@ -35,11 +39,11 @@ def __repr__(self): def connect_to(self, address): self.host, self.port = address - super().connect() - if self.connection_pool.check_connection: - self.send_command("PING") - if str_if_bytes(self.read_response()) != "PONG": - raise ConnectionError("PING failed") + + self.connect_check_health( + check_health=self.connection_pool.check_connection, + retry_socket_connect=False, + ) def _connect_retry(self): if self._sock: diff --git a/tests/test_asyncio/test_sentinel_managed_connection.py b/tests/test_asyncio/test_sentinel_managed_connection.py index 01f717ee38..5a511b2793 100644 --- a/tests/test_asyncio/test_sentinel_managed_connection.py +++ b/tests/test_asyncio/test_sentinel_managed_connection.py @@ -33,4 +33,5 @@ async def mock_connect(): conn._connect.side_effect = mock_connect await conn.connect() assert conn._connect.call_count == 3 + assert connection_pool.get_master_address.call_count == 3 await conn.disconnect() diff --git a/tests/test_sentinel_managed_connection.py b/tests/test_sentinel_managed_connection.py new file mode 100644 index 0000000000..6fe5f7cd5b --- /dev/null +++ b/tests/test_sentinel_managed_connection.py @@ -0,0 +1,34 @@ +import socket + +from redis.retry import Retry +from redis.sentinel import SentinelManagedConnection +from redis.backoff import NoBackoff +from unittest import mock + + +def test_connect_retry_on_timeout_error(master_host): + """Test that the _connect function is retried in case of a timeout""" + connection_pool = mock.Mock() + connection_pool.get_master_address = mock.Mock( + return_value=(master_host[0], master_host[1]) + ) + conn = SentinelManagedConnection( + retry_on_timeout=True, + retry=Retry(NoBackoff(), 3), + connection_pool=connection_pool, + ) + origin_connect = conn._connect + conn._connect = mock.Mock() + + def mock_connect(): + # connect only on the last retry + if conn._connect.call_count <= 2: + raise socket.timeout + else: + return origin_connect() + + conn._connect.side_effect = mock_connect + conn.connect() + assert conn._connect.call_count == 3 + assert connection_pool.get_master_address.call_count == 3 + conn.disconnect()