Skip to content

Commit

Permalink
Fix bug: client side caching causes unexpected disconnections (async …
Browse files Browse the repository at this point in the history
…version) (#3165)

* fix disconnects

* skip test in cluster

* add test

* save return value from handle_push_response (without it 'read_response' return the push message)

* insert return response from cache to the try block to prevent connection leak

* enable to get connection with data avaliable to read in csc mode and change can_read_destructive to not read data

* fix check if socket is empty (at_eof() can return False but this doesn't mean there's definitely more data to read)

---------

Co-authored-by: Chayim <chayim@users.noreply.github.com>
  • Loading branch information
dvora-h and chayim committed Feb 29, 2024
1 parent c573bc4 commit 26ab964
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 25 deletions.
2 changes: 1 addition & 1 deletion redis/_parsers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ async def can_read_destructive(self) -> bool:
return True
try:
async with async_timeout(0):
return await self._stream.read(1)
return self._stream.at_eof()
except TimeoutError:
return False

Expand Down
4 changes: 3 additions & 1 deletion redis/_parsers/resp3.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,9 @@ async def _read_response(
)
for _ in range(int(response))
]
await self.handle_push_response(response, disable_decoding, push_request)
response = await self.handle_push_response(
response, disable_decoding, push_request
)
else:
raise InvalidResponse(f"Protocol Error: {raw!r}")

Expand Down
40 changes: 21 additions & 19 deletions redis/asyncio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,25 +629,27 @@ async def execute_command(self, *args, **options):
pool = self.connection_pool
conn = self.connection or await pool.get_connection(command_name, **options)
response_from_cache = await conn._get_from_local_cache(args)
if response_from_cache is not None:
return response_from_cache
else:
if self.single_connection_client:
await self._single_conn_lock.acquire()
try:
response = await conn.retry.call_with_retry(
lambda: self._send_command_parse_response(
conn, command_name, *args, **options
),
lambda error: self._disconnect_raise(conn, error),
)
conn._add_to_local_cache(args, response, keys)
return response
finally:
if self.single_connection_client:
self._single_conn_lock.release()
if not self.connection:
await pool.release(conn)
try:
if response_from_cache is not None:
return response_from_cache
else:
try:
if self.single_connection_client:
await self._single_conn_lock.acquire()
response = await conn.retry.call_with_retry(
lambda: self._send_command_parse_response(
conn, command_name, *args, **options
),
lambda error: self._disconnect_raise(conn, error),
)
conn._add_to_local_cache(args, response, keys)
return response
finally:
if self.single_connection_client:
self._single_conn_lock.release()
finally:
if not self.connection:
await pool.release(conn)

async def parse_response(
self, connection: Connection, command_name: Union[str, bytes], **options
Expand Down
14 changes: 10 additions & 4 deletions redis/asyncio/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -685,7 +685,7 @@ def pack_commands(self, commands: Iterable[Iterable[EncodableT]]) -> List[bytes]

def _socket_is_empty(self):
"""Check if the socket is empty"""
return not self._reader.at_eof()
return len(self._reader._buffer) == 0

def _cache_invalidation_process(
self, data: List[Union[str, Optional[List[str]]]]
Expand Down Expand Up @@ -1192,12 +1192,18 @@ def make_connection(self):
async def ensure_connection(self, connection: AbstractConnection):
"""Ensure that the connection object is connected and valid"""
await connection.connect()
# connections that the pool provides should be ready to send
# a command. if not, the connection was either returned to the
# if client caching is not enabled connections that the pool
# provides should be ready to send a command.
# if not, the connection was either returned to the
# pool before all data has been read or the socket has been
# closed. either way, reconnect and verify everything is good.
# (if caching enabled the connection will not always be ready
# to send a command because it may contain invalidation messages)
try:
if await connection.can_read_destructive():
if (
await connection.can_read_destructive()
and connection.client_cache is None
):
raise ConnectionError("Connection has data") from None
except (ConnectionError, OSError):
await connection.disconnect()
Expand Down
42 changes: 42 additions & 0 deletions tests/test_asyncio/test_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,48 @@ async def test_cache_return_copy(self, r):
check = cache.get(("LRANGE", "mylist", 0, -1))
assert check == [b"baz", b"bar", b"foo"]

@pytest.mark.onlynoncluster
@pytest.mark.parametrize(
"r",
[{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}],
indirect=True,
)
async def test_csc_not_cause_disconnects(self, r):
r, cache = r
id1 = await r.client_id()
await r.mset({"a": 1, "b": 1, "c": 1, "d": 1, "e": 1})
assert await r.mget("a", "b", "c", "d", "e") == ["1", "1", "1", "1", "1"]
id2 = await r.client_id()

# client should get value from client cache
assert await r.mget("a", "b", "c", "d", "e") == ["1", "1", "1", "1", "1"]
assert cache.get(("MGET", "a", "b", "c", "d", "e")) == [
"1",
"1",
"1",
"1",
"1",
]

await r.mset({"a": 2, "b": 2, "c": 2, "d": 2, "e": 2})
id3 = await r.client_id()
# client should get value from redis server post invalidate messages
assert await r.mget("a", "b", "c", "d", "e") == ["2", "2", "2", "2", "2"]

await r.mset({"a": 3, "b": 3, "c": 3, "d": 3, "e": 3})
# need to check that we get correct value 3 and not 2
assert await r.mget("a", "b", "c", "d", "e") == ["3", "3", "3", "3", "3"]
# client should get value from client cache
assert await r.mget("a", "b", "c", "d", "e") == ["3", "3", "3", "3", "3"]

await r.mset({"a": 4, "b": 4, "c": 4, "d": 4, "e": 4})
# need to check that we get correct value 4 and not 3
assert await r.mget("a", "b", "c", "d", "e") == ["4", "4", "4", "4", "4"]
# client should get value from client cache
assert await r.mget("a", "b", "c", "d", "e") == ["4", "4", "4", "4", "4"]
id4 = await r.client_id()
assert id1 == id2 == id3 == id4


@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only")
@pytest.mark.onlycluster
Expand Down

0 comments on commit 26ab964

Please sign in to comment.