Skip to content

Commit

Permalink
Optionally disable disconnects in read_response (#2695)
Browse files Browse the repository at this point in the history
* Add regression tests and fixes for issue #1128

* Fix tests for resumable read_response to use "disconnect_on_error"

* undo previous fix attempts in async client and cluster

* re-enable cluster test

* Suggestions from code review

* Add CHANGES
  • Loading branch information
kristjanvalur committed May 8, 2023
1 parent 093232d commit c0833f6
Show file tree
Hide file tree
Showing 11 changed files with 149 additions and 110 deletions.
1 change: 1 addition & 0 deletions CHANGES
@@ -1,3 +1,4 @@
* Revert #2104, #2673, add `disconnect_on_error` option to `read_response()` (issues #2506, #2624)
* Add `address_remap` parameter to `RedisCluster`
* Fix incorrect usage of once flag in async Sentinel
* asyncio: Fix memory leak caused by hiredis (#2693)
Expand Down
93 changes: 27 additions & 66 deletions redis/asyncio/client.py
Expand Up @@ -500,23 +500,6 @@ async def _disconnect_raise(self, conn: Connection, error: Exception):
):
raise error

async def _try_send_command_parse_response(self, conn, *args, **options):
try:
return await conn.retry.call_with_retry(
lambda: self._send_command_parse_response(
conn, args[0], *args, **options
),
lambda error: self._disconnect_raise(conn, error),
)
except asyncio.CancelledError:
await conn.disconnect(nowait=True)
raise
finally:
if self.single_connection_client:
self._single_conn_lock.release()
if not self.connection:
await self.connection_pool.release(conn)

# COMMAND EXECUTION AND PROTOCOL PARSING
async def execute_command(self, *args, **options):
"""Execute a command and return a parsed response"""
Expand All @@ -527,10 +510,18 @@ async def execute_command(self, *args, **options):

if self.single_connection_client:
await self._single_conn_lock.acquire()

return await asyncio.shield(
self._try_send_command_parse_response(conn, *args, **options)
)
try:
return await conn.retry.call_with_retry(
lambda: self._send_command_parse_response(
conn, command_name, *args, **options
),
lambda error: self._disconnect_raise(conn, error),
)
finally:
if self.single_connection_client:
self._single_conn_lock.release()
if not self.connection:
await pool.release(conn)

async def parse_response(
self, connection: Connection, command_name: Union[str, bytes], **options
Expand Down Expand Up @@ -774,18 +765,10 @@ async def _disconnect_raise_connect(self, conn, error):
is not a TimeoutError. Otherwise, try to reconnect
"""
await conn.disconnect()

if not (conn.retry_on_timeout and isinstance(error, TimeoutError)):
raise error
await conn.connect()

async def _try_execute(self, conn, command, *arg, **kwargs):
try:
return await command(*arg, **kwargs)
except asyncio.CancelledError:
await conn.disconnect()
raise

async def _execute(self, conn, command, *args, **kwargs):
"""
Connect manually upon disconnection. If the Redis server is down,
Expand All @@ -794,11 +777,9 @@ async def _execute(self, conn, command, *args, **kwargs):
called by the connection to resubscribe us to any channels and
patterns we were previously listening to
"""
return await asyncio.shield(
conn.retry.call_with_retry(
lambda: self._try_execute(conn, command, *args, **kwargs),
lambda error: self._disconnect_raise_connect(conn, error),
)
return await conn.retry.call_with_retry(
lambda: command(*args, **kwargs),
lambda error: self._disconnect_raise_connect(conn, error),
)

async def parse_response(self, block: bool = True, timeout: float = 0):
Expand All @@ -816,7 +797,9 @@ async def parse_response(self, block: bool = True, timeout: float = 0):
await conn.connect()

read_timeout = None if block else timeout
response = await self._execute(conn, conn.read_response, timeout=read_timeout)
response = await self._execute(
conn, conn.read_response, timeout=read_timeout, disconnect_on_error=False
)

if conn.health_check_interval and response == self.health_check_response:
# ignore the health check message as user might not expect it
Expand Down Expand Up @@ -1200,18 +1183,6 @@ async def _disconnect_reset_raise(self, conn, error):
await self.reset()
raise

async def _try_send_command_parse_response(self, conn, *args, **options):
try:
return await conn.retry.call_with_retry(
lambda: self._send_command_parse_response(
conn, args[0], *args, **options
),
lambda error: self._disconnect_reset_raise(conn, error),
)
except asyncio.CancelledError:
await conn.disconnect()
raise

async def immediate_execute_command(self, *args, **options):
"""
Execute a command immediately, but don't auto-retry on a
Expand All @@ -1227,8 +1198,12 @@ async def immediate_execute_command(self, *args, **options):
command_name, self.shard_hint
)
self.connection = conn
return await asyncio.shield(
self._try_send_command_parse_response(conn, *args, **options)

return await conn.retry.call_with_retry(
lambda: self._send_command_parse_response(
conn, command_name, *args, **options
),
lambda error: self._disconnect_reset_raise(conn, error),
)

def pipeline_execute_command(self, *args, **options):
Expand Down Expand Up @@ -1396,19 +1371,6 @@ async def _disconnect_raise_reset(self, conn: Connection, error: Exception):
await self.reset()
raise

async def _try_execute(self, conn, execute, stack, raise_on_error):
try:
return await conn.retry.call_with_retry(
lambda: execute(conn, stack, raise_on_error),
lambda error: self._disconnect_raise_reset(conn, error),
)
except asyncio.CancelledError:
# not supposed to be possible, yet here we are
await conn.disconnect(nowait=True)
raise
finally:
await self.reset()

async def execute(self, raise_on_error: bool = True):
"""Execute all the commands in the current pipeline"""
stack = self.command_stack
Expand All @@ -1430,11 +1392,10 @@ async def execute(self, raise_on_error: bool = True):
conn = cast(Connection, conn)

try:
return await asyncio.shield(
self._try_execute(conn, execute, stack, raise_on_error)
return await conn.retry.call_with_retry(
lambda: execute(conn, stack, raise_on_error),
lambda error: self._disconnect_raise_reset(conn, error),
)
except RuntimeError:
await self.reset()
finally:
await self.reset()

Expand Down
33 changes: 9 additions & 24 deletions redis/asyncio/cluster.py
Expand Up @@ -1016,33 +1016,12 @@ async def execute_command(self, *args: Any, **kwargs: Any) -> Any:
await connection.send_packed_command(connection.pack_command(*args), False)

# Read response
return await asyncio.shield(
self._parse_and_release(connection, args[0], **kwargs)
)

async def _parse_and_release(self, connection, *args, **kwargs):
try:
return await self.parse_response(connection, *args, **kwargs)
except asyncio.CancelledError:
# should not be possible
await connection.disconnect(nowait=True)
raise
return await self.parse_response(connection, args[0], **kwargs)
finally:
# Release connection
self._free.append(connection)

async def _try_parse_response(self, cmd, connection, ret):
try:
cmd.result = await asyncio.shield(
self.parse_response(connection, cmd.args[0], **cmd.kwargs)
)
except asyncio.CancelledError:
await connection.disconnect(nowait=True)
raise
except Exception as e:
cmd.result = e
ret = True
return ret

async def execute_pipeline(self, commands: List["PipelineCommand"]) -> bool:
# Acquire connection
connection = self.acquire_connection()
Expand All @@ -1055,7 +1034,13 @@ async def execute_pipeline(self, commands: List["PipelineCommand"]) -> bool:
# Read responses
ret = False
for cmd in commands:
ret = await asyncio.shield(self._try_parse_response(cmd, connection, ret))
try:
cmd.result = await self.parse_response(
connection, cmd.args[0], **cmd.kwargs
)
except Exception as e:
cmd.result = e
ret = True

# Release connection
self._free.append(connection)
Expand Down
28 changes: 18 additions & 10 deletions redis/asyncio/connection.py
Expand Up @@ -804,7 +804,11 @@ async def send_packed_command(
raise ConnectionError(
f"Error {err_no} while writing to socket. {errmsg}."
) from e
except Exception:
except BaseException:
# BaseExceptions can be raised when a socket send operation is not
# finished, e.g. due to a timeout. Ideally, a caller could then re-try
# to send un-sent data. However, the send_packed_command() API
# does not support it so there is no point in keeping the connection open.
await self.disconnect(nowait=True)
raise

Expand All @@ -828,6 +832,8 @@ async def read_response(
self,
disable_decoding: bool = False,
timeout: Optional[float] = None,
*,
disconnect_on_error: bool = True,
):
"""Read the response from a previously sent command"""
read_timeout = timeout if timeout is not None else self.socket_timeout
Expand All @@ -843,22 +849,24 @@ async def read_response(
)
except asyncio.TimeoutError:
if timeout is not None:
# user requested timeout, return None
# user requested timeout, return None. Operation can be retried
return None
# it was a self.socket_timeout error.
await self.disconnect(nowait=True)
if disconnect_on_error:
await self.disconnect(nowait=True)
raise TimeoutError(f"Timeout reading from {self.host}:{self.port}")
except OSError as e:
await self.disconnect(nowait=True)
if disconnect_on_error:
await self.disconnect(nowait=True)
raise ConnectionError(
f"Error while reading from {self.host}:{self.port} : {e.args}"
)
except asyncio.CancelledError:
# need this check for 3.7, where CancelledError
# is subclass of Exception, not BaseException
raise
except Exception:
await self.disconnect(nowait=True)
except BaseException:
# Also by default close in case of BaseException. A lot of code
# relies on this behaviour when doing Command/Response pairs.
# See #1128.
if disconnect_on_error:
await self.disconnect(nowait=True)
raise

if self.health_check_interval:
Expand Down
2 changes: 1 addition & 1 deletion redis/client.py
Expand Up @@ -1533,7 +1533,7 @@ def try_read():
return None
else:
conn.connect()
return conn.read_response()
return conn.read_response(disconnect_on_error=False)

response = self._execute(conn, try_read)

Expand Down
24 changes: 18 additions & 6 deletions redis/connection.py
Expand Up @@ -834,7 +834,11 @@ def send_packed_command(self, command, check_health=True):
errno = e.args[0]
errmsg = e.args[1]
raise ConnectionError(f"Error {errno} while writing to socket. {errmsg}.")
except Exception:
except BaseException:
# BaseExceptions can be raised when a socket send operation is not
# finished, e.g. due to a timeout. Ideally, a caller could then re-try
# to send un-sent data. However, the send_packed_command() API
# does not support it so there is no point in keeping the connection open.
self.disconnect()
raise

Expand All @@ -859,23 +863,31 @@ def can_read(self, timeout=0):
self.disconnect()
raise ConnectionError(f"Error while reading from {host_error}: {e.args}")

def read_response(self, disable_decoding=False):
def read_response(
self, disable_decoding=False, *, disconnect_on_error: bool = True
):
"""Read the response from a previously sent command"""

host_error = self._host_error()

try:
response = self._parser.read_response(disable_decoding=disable_decoding)
except socket.timeout:
self.disconnect()
if disconnect_on_error:
self.disconnect()
raise TimeoutError(f"Timeout reading from {host_error}")
except OSError as e:
self.disconnect()
if disconnect_on_error:
self.disconnect()
raise ConnectionError(
f"Error while reading from {host_error}" f" : {e.args}"
)
except Exception:
self.disconnect()
except BaseException:
# Also by default close in case of BaseException. A lot of code
# relies on this behaviour when doing Command/Response pairs.
# See #1128.
if disconnect_on_error:
self.disconnect()
raise

if self.health_check_interval:
Expand Down
38 changes: 38 additions & 0 deletions tests/test_asyncio/test_commands.py
@@ -1,9 +1,11 @@
"""
Tests async overrides of commands from their mixins
"""
import asyncio
import binascii
import datetime
import re
import sys
from string import ascii_letters

import pytest
Expand All @@ -18,6 +20,11 @@
skip_unless_arch_bits,
)

if sys.version_info >= (3, 11, 3):
from asyncio import timeout as async_timeout
else:
from async_timeout import timeout as async_timeout

REDIS_6_VERSION = "5.9.0"


Expand Down Expand Up @@ -3008,6 +3015,37 @@ async def test_module_list(self, r: redis.Redis):
for x in await r.module_list():
assert isinstance(x, dict)

@pytest.mark.onlynoncluster
async def test_interrupted_command(self, r: redis.Redis):
"""
Regression test for issue #1128: an unhandled BaseException
would leave the socket with an unread response to a previous
command.
"""
ready = asyncio.Event()

async def helper():
with pytest.raises(asyncio.CancelledError):
# blocking pop
ready.set()
await r.brpop(["nonexist"])
# If the following is not done, further Timeout operations will fail,
# because the timeout won't catch its CancelledError if the task
# has a pending cancel. Python documentation probably should reflect this.
if sys.version_info >= (3, 11):
asyncio.current_task().uncancel()
# if all is well, we can continue. The following should not hang.
await r.set("status", "down")

task = asyncio.create_task(helper())
await ready.wait()
await asyncio.sleep(0.01)
# the task is now sleeping, let's send it an exception
task.cancel()
# If all is well, the task should finish right away, otherwise fail with Timeout
async with async_timeout(0.1):
await task


@pytest.mark.onlynoncluster
class TestBinarySave:
Expand Down

0 comments on commit c0833f6

Please sign in to comment.