Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 39 additions & 33 deletions redis/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -1365,7 +1365,7 @@ def __init__(
self._pool_lock = pool_lock
self._cache = cache
self._cache_lock = threading.RLock()
self._current_command_cache_key = None
self._current_command_cache_entry = None
self._current_options = None
self.register_connect_callback(self._enable_tracking_callback)

Expand Down Expand Up @@ -1453,42 +1453,49 @@ def send_command(self, *args, **kwargs):
if not self._cache.is_cachable(
CacheKey(command=args[0], redis_keys=(), redis_args=())
):
self._current_command_cache_key = None
self._current_command_cache_entry = None
self._conn.send_command(*args, **kwargs)
return

if kwargs.get("keys") is None:
raise ValueError("Cannot create cache key.")

# Creates cache key.
self._current_command_cache_key = CacheKey(
cache_key = CacheKey(
command=args[0], redis_keys=tuple(kwargs.get("keys")), redis_args=args
)
self._current_command_cache_entry = None

with self._cache_lock:
# We have to trigger invalidation processing in case if
# it was cached by another connection to avoid
# queueing invalidations in stale connections.
if self._cache.get(self._current_command_cache_key):
entry = self._cache.get(self._current_command_cache_key)

if entry.connection_ref != self._conn:
cache_entry = self._cache.get(cache_key)
if cache_entry is not None and cache_entry.status == CacheEntryStatus.VALID:
# We have to trigger invalidation processing in case
# it was cached by another connection, to avoid
# queueing invalidations in stale connections.
if cache_entry.connection_ref != self._conn:
with self._pool_lock:
while entry.connection_ref.can_read():
entry.connection_ref.read_response(push_request=True)

return
while cache_entry.connection_ref.can_read():
cache_entry.connection_ref.read_response(push_request=True)
# Check if entry still exists.
if self._cache.get(cache_key) is not None:
Copy link

Copilot AI Jan 9, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The check at line 1479 only verifies that some entry exists in the cache, but doesn't verify it's the same entry object as cache_entry. If the entry was invalidated and a new entry was created by another thread between lines 1477 and 1479, the condition would pass but line 1480 would store a reference to the old invalidated entry. This could lead to using stale cached data. The check should verify the entry identity: if self._cache.get(cache_key) is cache_entry:

Suggested change
# Check if entry still exists.
if self._cache.get(cache_key) is not None:
# Check if the same entry still exists in the cache.
if self._cache.get(cache_key) is cache_entry:

Copilot uses AI. Check for mistakes.
self._current_command_cache_entry = cache_entry
return
cache_entry = None
else:
self._current_command_cache_entry = cache_entry
return

# Set temporary entry value to prevent
# race condition from another connection.
self._cache.set(
CacheEntry(
cache_key=self._current_command_cache_key,
if cache_entry is None:
# Creates cache entry.
cache_entry = CacheEntry(
cache_key=cache_key,
cache_value=self.DUMMY_CACHE_VALUE,
status=CacheEntryStatus.IN_PROGRESS,
connection_ref=self._conn,
)
)
# Set temporary entry value to prevent
# race condition from another connection.
self._cache.set(cache_entry)
self._current_command_cache_entry = cache_entry

# Send the command over the socket only if it's an allowed
# read-only command that is not yet cached.
Expand All @@ -1501,17 +1508,15 @@ def read_response(
self, disable_decoding=False, *, disconnect_on_error=True, push_request=False
):
with self._cache_lock:
# Check if command response exists in a cache and it's not in progress.
# Check if command response cache entry exists and it's valid.
if (
self._current_command_cache_key is not None
and self._cache.get(self._current_command_cache_key) is not None
and self._cache.get(self._current_command_cache_key).status
!= CacheEntryStatus.IN_PROGRESS
self._current_command_cache_entry is not None
and self._current_command_cache_entry.status == CacheEntryStatus.VALID
):
res = copy.deepcopy(
self._cache.get(self._current_command_cache_key).cache_value
self._current_command_cache_entry.cache_value
)
self._current_command_cache_key = None
self._current_command_cache_entry = None
return res

response = self._conn.read_response(
Expand All @@ -1522,23 +1527,24 @@ def read_response(

with self._cache_lock:
# Prevent not-allowed command from caching.
if self._current_command_cache_key is None:
if self._current_command_cache_entry is None:
return response
# If the response is None, prevent caching.
cache_key = self._current_command_cache_entry.cache_key
if response is None:
self._cache.delete_by_cache_keys([self._current_command_cache_key])
self._cache.delete_by_cache_keys([cache_key])
return response

cache_entry = self._cache.get(self._current_command_cache_key)
cache_entry = self._cache.get(cache_key)

# Cache only responses that are still valid
# and weren't invalidated by another connection in the meantime.
if cache_entry is not None:
if cache_entry is self._current_command_cache_entry:
cache_entry.status = CacheEntryStatus.VALID
cache_entry.cache_value = response
self._cache.set(cache_entry)

self._current_command_cache_key = None
self._current_command_cache_entry = None

return response

Expand Down
120 changes: 119 additions & 1 deletion tests/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,7 @@ def test_read_response_returns_cached_reply(self, mock_cache, mock_connection):
)
proxy_connection.send_command(*["GET", "foo"], **{"keys": ["foo"]})
assert proxy_connection.read_response() == b"bar"
assert proxy_connection._current_command_cache_key is None
assert proxy_connection._current_command_cache_entry is None
assert proxy_connection.read_response() == b"bar"

mock_cache.set.assert_has_calls(
Expand Down Expand Up @@ -613,3 +613,121 @@ def test_triggers_invalidation_processing_on_another_connection(
assert proxy_connection.read_response() == b"bar"
assert another_conn.can_read.call_count == 2
another_conn.read_response.assert_called_once()


@pytest.mark.skipif(
    platform.python_implementation() == "PyPy",
    reason="Pypy doesn't support side_effect",
)
def test_cache_entry_in_progress(
    self, mock_cache, mock_connection
):
    """A cache entry still IN_PROGRESS (owned by another connection) must
    not be served: the proxy falls through to the real connection and the
    fresh response (b"bar2") wins over the in-flight cached value (b"bar").
    """
    # Minimal attributes the proxy reads from the wrapped connection.
    mock_connection.retry = "mock"
    mock_connection.host = "mock"
    mock_connection.port = "mock"
    mock_connection.credential_provider = UsernamePasswordCredentialProvider()

    # A second connection "owns" the in-progress entry; it reports nothing
    # pending to read, so invalidation processing on it is a no-op.
    another_conn = copy.deepcopy(mock_connection)
    another_conn.can_read.return_value = False
    cache_entry = CacheEntry(
        cache_key=CacheKey(
            command="GET", redis_keys=("foo",), redis_args=("GET", "foo")
        ),
        cache_value=b"bar",
        # Not yet VALID — must be ignored by the reading proxy.
        status=CacheEntryStatus.IN_PROGRESS,
        connection_ref=another_conn,
    )
    mock_cache.is_cachable.return_value = True
    mock_cache.get.return_value = cache_entry
    mock_connection.can_read.return_value = False
    # The real connection's answer, distinct from the stale cached value.
    mock_connection.read_response.return_value = b"bar2"

    proxy_connection = CacheProxyConnection(
        mock_connection, mock_cache, threading.RLock()
    )
    proxy_connection.send_command(*["GET", "foo"], **{"keys": ["foo"]})

    # The IN_PROGRESS value is skipped: the command actually hit the wire.
    assert proxy_connection.read_response() == b"bar2"
    mock_connection.send_command.assert_called_once()
    mock_connection.read_response.assert_called_once()


@pytest.mark.skipif(
    platform.python_implementation() == "PyPy",
    reason="Pypy doesn't support side_effect",
)
def test_cache_entry_gone_between_send_and_read(
    self, mock_cache, mock_connection
):
    """If the VALID entry found during send_command disappears from the
    cache before read_response, the proxy still serves the entry it holds
    a direct reference to — without touching the network.
    """
    # Minimal attributes the proxy reads from the wrapped connection.
    mock_connection.retry = "mock"
    mock_connection.host = "mock"
    mock_connection.port = "mock"
    mock_connection.credential_provider = UsernamePasswordCredentialProvider()

    # The entry was populated by another connection with nothing pending,
    # so cross-connection invalidation processing is a no-op.
    another_conn = copy.deepcopy(mock_connection)
    another_conn.can_read.return_value = False
    cache_entry = CacheEntry(
        cache_key=CacheKey(
            command="GET", redis_keys=("foo",), redis_args=("GET", "foo")
        ),
        cache_value=b"bar",
        status=CacheEntryStatus.VALID,
        connection_ref=another_conn,
    )
    mock_cache.is_cachable.return_value = True
    mock_cache.get.return_value = cache_entry
    mock_connection.can_read.return_value = False
    mock_connection.read_response.return_value = None

    proxy_connection = CacheProxyConnection(
        mock_connection, mock_cache, threading.RLock()
    )
    proxy_connection.send_command(*["GET", "foo"], **{"keys": ["foo"]})

    # cache entry gone
    mock_cache.get.return_value = None

    # The retained entry reference still yields the cached value, and
    # neither send nor read ever reached the real connection.
    assert proxy_connection.read_response() == b"bar"
    mock_connection.send_command.assert_not_called()
    mock_connection.read_response.assert_not_called()


@pytest.mark.skipif(
    platform.python_implementation() == "PyPy",
    reason="Pypy doesn't support side_effect",
)
def test_cache_entry_fill_between_send_and_read(
    self, mock_cache, mock_connection
):
    """If the cache was empty at send_command time but another connection
    fills a VALID entry before read_response, the proxy still returns its
    own in-flight response (b"bar2") and the command goes over the wire.
    """
    # Minimal attributes the proxy reads from the wrapped connection.
    mock_connection.retry = "mock"
    mock_connection.host = "mock"
    mock_connection.port = "mock"
    mock_connection.credential_provider = UsernamePasswordCredentialProvider()

    # Another connection that will own the late-arriving entry.
    another_conn = copy.deepcopy(mock_connection)
    another_conn.can_read.return_value = False

    # Cache miss at send time → the command is actually dispatched.
    mock_cache.is_cachable.return_value = True
    mock_cache.get.return_value = None
    mock_connection.can_read.return_value = False
    mock_connection.read_response.return_value = b"bar2"

    proxy_connection = CacheProxyConnection(
        mock_connection, mock_cache, threading.RLock()
    )
    proxy_connection.send_command(*["GET", "foo"], **{"keys": ["foo"]})

    cache_entry = CacheEntry(
        cache_key=CacheKey(
            command="GET", redis_keys=("foo",), redis_args=("GET", "foo")
        ),
        cache_value=b"bar",
        status=CacheEntryStatus.VALID,
        connection_ref=another_conn,
    )
    # cache entry fill
    mock_cache.get.return_value = cache_entry

    # The foreign entry (b"bar") is not served for this in-flight command.
    assert proxy_connection.read_response() == b"bar2"
    mock_connection.send_command.assert_called_once()
    mock_connection.read_response.assert_called_once()
Loading