From c0833f60a1d9ec85c589004aba6b6739e6298248 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Mon, 8 May 2023 10:11:43 +0000 Subject: [PATCH] Optionally disable disconnects in read_response (#2695) * Add regression tests and fixes for issue #1128 * Fix tests for resumable read_response to use "disconnect_on_error" * undo prevision fix attempts in async client and cluster * re-enable cluster test * Suggestions from code review * Add CHANGES --- CHANGES | 1 + redis/asyncio/client.py | 93 ++++++++------------------- redis/asyncio/cluster.py | 33 +++------- redis/asyncio/connection.py | 28 +++++--- redis/client.py | 2 +- redis/connection.py | 24 +++++-- tests/test_asyncio/test_commands.py | 38 +++++++++++ tests/test_asyncio/test_connection.py | 2 +- tests/test_asyncio/test_cwe_404.py | 1 - tests/test_commands.py | 35 ++++++++++ tests/test_connection.py | 2 +- 11 files changed, 149 insertions(+), 110 deletions(-) diff --git a/CHANGES b/CHANGES index 3865ed1067..ea171f6bb5 100644 --- a/CHANGES +++ b/CHANGES @@ -1,3 +1,4 @@ + * Revert #2104, #2673, add `disconnect_on_error` option to `read_response()` (issues #2506, #2624) * Add `address_remap` parameter to `RedisCluster` * Fix incorrect usage of once flag in async Sentinel * asyncio: Fix memory leak caused by hiredis (#2693) diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index 5fb94b3353..a7b888eee4 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -500,23 +500,6 @@ async def _disconnect_raise(self, conn: Connection, error: Exception): ): raise error - async def _try_send_command_parse_response(self, conn, *args, **options): - try: - return await conn.retry.call_with_retry( - lambda: self._send_command_parse_response( - conn, args[0], *args, **options - ), - lambda error: self._disconnect_raise(conn, error), - ) - except asyncio.CancelledError: - await conn.disconnect(nowait=True) - raise - finally: - if self.single_connection_client: - self._single_conn_lock.release() - if not self.connection: - await self.connection_pool.release(conn) - # COMMAND EXECUTION AND PROTOCOL PARSING async def execute_command(self, *args, **options): """Execute a command and return a parsed response""" @@ -527,10 +510,18 @@ async def execute_command(self, *args, **options): if self.single_connection_client: await self._single_conn_lock.acquire() - - return await asyncio.shield( - self._try_send_command_parse_response(conn, *args, **options) - ) + try: + return await conn.retry.call_with_retry( + lambda: self._send_command_parse_response( + conn, command_name, *args, **options + ), + lambda error: self._disconnect_raise(conn, error), + ) + finally: + if self.single_connection_client: + self._single_conn_lock.release() + if not self.connection: + await pool.release(conn) async def parse_response( self, connection: Connection, command_name: Union[str, bytes], **options @@ -774,18 +765,10 @@ async def _disconnect_raise_connect(self, conn, error): is not a TimeoutError. Otherwise, try to reconnect """ await conn.disconnect() - if not (conn.retry_on_timeout and isinstance(error, TimeoutError)): raise error await conn.connect() - async def _try_execute(self, conn, command, *arg, **kwargs): - try: - return await command(*arg, **kwargs) - except asyncio.CancelledError: - await conn.disconnect() - raise - async def _execute(self, conn, command, *args, **kwargs): """ Connect manually upon disconnection. If the Redis server is down, @@ -794,11 +777,9 @@ async def _execute(self, conn, command, *args, **kwargs): called by the # connection to resubscribe us to any channels and patterns we were previously listening to """ - return await asyncio.shield( - conn.retry.call_with_retry( - lambda: self._try_execute(conn, command, *args, **kwargs), - lambda error: self._disconnect_raise_connect(conn, error), - ) + return await conn.retry.call_with_retry( + lambda: command(*args, **kwargs), + lambda error: self._disconnect_raise_connect(conn, error), ) async def parse_response(self, block: bool = True, timeout: float = 0): @@ -816,7 +797,9 @@ async def parse_response(self, block: bool = True, timeout: float = 0): await conn.connect() read_timeout = None if block else timeout - response = await self._execute(conn, conn.read_response, timeout=read_timeout) + response = await self._execute( + conn, conn.read_response, timeout=read_timeout, disconnect_on_error=False + ) if conn.health_check_interval and response == self.health_check_response: # ignore the health check message as user might not expect it @@ -1200,18 +1183,6 @@ async def _disconnect_reset_raise(self, conn, error): await self.reset() raise - async def _try_send_command_parse_response(self, conn, *args, **options): - try: - return await conn.retry.call_with_retry( - lambda: self._send_command_parse_response( - conn, args[0], *args, **options - ), - lambda error: self._disconnect_reset_raise(conn, error), - ) - except asyncio.CancelledError: - await conn.disconnect() - raise - async def immediate_execute_command(self, *args, **options): """ Execute a command immediately, but don't auto-retry on a @@ -1227,8 +1198,12 @@ async def immediate_execute_command(self, *args, **options): command_name, self.shard_hint ) self.connection = conn - return await asyncio.shield( - self._try_send_command_parse_response(conn, *args, **options) + + return await conn.retry.call_with_retry( + lambda: self._send_command_parse_response( + conn, command_name, *args, **options + ), + lambda error: self._disconnect_reset_raise(conn, error), ) def pipeline_execute_command(self, *args, **options): @@ -1396,19 +1371,6 @@ async def _disconnect_raise_reset(self, conn: Connection, error: Exception): await self.reset() raise - async def _try_execute(self, conn, execute, stack, raise_on_error): - try: - return await conn.retry.call_with_retry( - lambda: execute(conn, stack, raise_on_error), - lambda error: self._disconnect_raise_reset(conn, error), - ) - except asyncio.CancelledError: - # not supposed to be possible, yet here we are - await conn.disconnect(nowait=True) - raise - finally: - await self.reset() - async def execute(self, raise_on_error: bool = True): """Execute all the commands in the current pipeline""" stack = self.command_stack @@ -1430,11 +1392,10 @@ async def execute(self, raise_on_error: bool = True): conn = cast(Connection, conn) try: - return await asyncio.shield( - self._try_execute(conn, execute, stack, raise_on_error) + return await conn.retry.call_with_retry( + lambda: execute(conn, stack, raise_on_error), + lambda error: self._disconnect_raise_reset(conn, error), ) - except RuntimeError: - await self.reset() finally: await self.reset() diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index eb5f4db061..929d3e47c7 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -1016,33 +1016,12 @@ async def execute_command(self, *args: Any, **kwargs: Any) -> Any: await connection.send_packed_command(connection.pack_command(*args), False) # Read response - return await asyncio.shield( - self._parse_and_release(connection, args[0], **kwargs) - ) - - async def _parse_and_release(self, connection, *args, **kwargs): try: - return await self.parse_response(connection, *args, **kwargs) - except asyncio.CancelledError: - # should not be possible - await connection.disconnect(nowait=True) - raise + return await self.parse_response(connection, args[0], **kwargs) finally: + # Release connection self._free.append(connection) - async def _try_parse_response(self, cmd, connection, ret): - try: - cmd.result = await asyncio.shield( - self.parse_response(connection, cmd.args[0], **cmd.kwargs) - ) - except asyncio.CancelledError: - await connection.disconnect(nowait=True) - raise - except Exception as e: - cmd.result = e - ret = True - return ret - async def execute_pipeline(self, commands: List["PipelineCommand"]) -> bool: # Acquire connection connection = self.acquire_connection() @@ -1055,7 +1034,13 @@ async def execute_pipeline(self, commands: List["PipelineCommand"]) -> bool: # Read responses ret = False for cmd in commands: - ret = await asyncio.shield(self._try_parse_response(cmd, connection, ret)) + try: + cmd.result = await self.parse_response( + connection, cmd.args[0], **cmd.kwargs + ) + except Exception as e: + cmd.result = e + ret = True # Release connection self._free.append(connection) diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index 59f75aa229..462673f2ed 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -804,7 +804,11 @@ async def send_packed_command( raise ConnectionError( f"Error {err_no} while writing to socket. {errmsg}." ) from e - except Exception: + except BaseException: + # BaseExceptions can be raised when a socket send operation is not + # finished, e.g. due to a timeout. Ideally, a caller could then re-try + # to send un-sent data. However, the send_packed_command() API + # does not support it so there is no point in keeping the connection open. await self.disconnect(nowait=True) raise @@ -828,6 +832,8 @@ async def read_response( self, disable_decoding: bool = False, timeout: Optional[float] = None, + *, + disconnect_on_error: bool = True, ): """Read the response from a previously sent command""" read_timeout = timeout if timeout is not None else self.socket_timeout @@ -843,22 +849,24 @@ async def read_response( ) except asyncio.TimeoutError: if timeout is not None: - # user requested timeout, return None + # user requested timeout, return None. Operation can be retried return None # it was a self.socket_timeout error. - await self.disconnect(nowait=True) + if disconnect_on_error: + await self.disconnect(nowait=True) raise TimeoutError(f"Timeout reading from {self.host}:{self.port}") except OSError as e: - await self.disconnect(nowait=True) + if disconnect_on_error: + await self.disconnect(nowait=True) raise ConnectionError( f"Error while reading from {self.host}:{self.port} : {e.args}" ) - except asyncio.CancelledError: - # need this check for 3.7, where CancelledError - # is subclass of Exception, not BaseException - raise - except Exception: - await self.disconnect(nowait=True) + except BaseException: + # Also by default close in case of BaseException. A lot of code + # relies on this behaviour when doing Command/Response pairs. + # See #1128. + if disconnect_on_error: + await self.disconnect(nowait=True) raise if self.health_check_interval: diff --git a/redis/client.py b/redis/client.py index c43a388358..65d0cec9a9 100755 --- a/redis/client.py +++ b/redis/client.py @@ -1533,7 +1533,7 @@ def try_read(): return None else: conn.connect() - return conn.read_response() + return conn.read_response(disconnect_on_error=False) response = self._execute(conn, try_read) diff --git a/redis/connection.py b/redis/connection.py index 8b2389c6db..5af8928a5d 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -834,7 +834,11 @@ def send_packed_command(self, command, check_health=True): errno = e.args[0] errmsg = e.args[1] raise ConnectionError(f"Error {errno} while writing to socket. {errmsg}.") - except Exception: + except BaseException: + # BaseExceptions can be raised when a socket send operation is not + # finished, e.g. due to a timeout. Ideally, a caller could then re-try + # to send un-sent data. However, the send_packed_command() API + # does not support it so there is no point in keeping the connection open. self.disconnect() raise @@ -859,7 +863,9 @@ def can_read(self, timeout=0): self.disconnect() raise ConnectionError(f"Error while reading from {host_error}: {e.args}") - def read_response(self, disable_decoding=False): + def read_response( + self, disable_decoding=False, *, disconnect_on_error: bool = True + ): """Read the response from a previously sent command""" host_error = self._host_error() @@ -867,15 +873,21 @@ def read_response(self, disable_decoding=False): try: response = self._parser.read_response(disable_decoding=disable_decoding) except socket.timeout: - self.disconnect() + if disconnect_on_error: + self.disconnect() raise TimeoutError(f"Timeout reading from {host_error}") except OSError as e: - self.disconnect() + if disconnect_on_error: + self.disconnect() raise ConnectionError( f"Error while reading from {host_error}" f" : {e.args}" ) - except Exception: - self.disconnect() + except BaseException: + # Also by default close in case of BaseException. A lot of code + # relies on this behaviour when doing Command/Response pairs. + # See #1128. + if disconnect_on_error: + self.disconnect() raise if self.health_check_interval: diff --git a/tests/test_asyncio/test_commands.py b/tests/test_asyncio/test_commands.py index 409934c9a3..ac3537db52 100644 --- a/tests/test_asyncio/test_commands.py +++ b/tests/test_asyncio/test_commands.py @@ -1,9 +1,11 @@ """ Tests async overrides of commands from their mixins """ +import asyncio import binascii import datetime import re +import sys from string import ascii_letters import pytest @@ -18,6 +20,11 @@ skip_unless_arch_bits, ) +if sys.version_info >= (3, 11, 3): + from asyncio import timeout as async_timeout +else: + from async_timeout import timeout as async_timeout + REDIS_6_VERSION = "5.9.0" @@ -3008,6 +3015,37 @@ async def test_module_list(self, r: redis.Redis): for x in await r.module_list(): assert isinstance(x, dict) + @pytest.mark.onlynoncluster + async def test_interrupted_command(self, r: redis.Redis): + """ + Regression test for issue #1128: An Un-handled BaseException + will leave the socket with un-read response to a previous + command. + """ + ready = asyncio.Event() + + async def helper(): + with pytest.raises(asyncio.CancelledError): + # blocking pop + ready.set() + await r.brpop(["nonexist"]) + # If the following is not done, further Timout operations will fail, + # because the timeout won't catch its Cancelled Error if the task + # has a pending cancel. Python documentation probably should reflect this. + if sys.version_info >= (3, 11): + asyncio.current_task().uncancel() + # if all is well, we can continue. The following should not hang. + await r.set("status", "down") + + task = asyncio.create_task(helper()) + await ready.wait() + await asyncio.sleep(0.01) + # the task is now sleeping, lets send it an exception + task.cancel() + # If all is well, the task should finish right away, otherwise fail with Timeout + async with async_timeout(0.1): + await task + @pytest.mark.onlynoncluster class TestBinarySave: diff --git a/tests/test_asyncio/test_connection.py b/tests/test_asyncio/test_connection.py index e2d77fc1c3..e49dd42204 100644 --- a/tests/test_asyncio/test_connection.py +++ b/tests/test_asyncio/test_connection.py @@ -184,7 +184,7 @@ async def test_connection_parse_response_resume(r: redis.Redis): conn._parser._stream = MockStream(message, interrupt_every=2) for i in range(100): try: - response = await conn.read_response() + response = await conn.read_response(disconnect_on_error=False) break except MockStream.TestError: pass diff --git a/tests/test_asyncio/test_cwe_404.py b/tests/test_asyncio/test_cwe_404.py index d3a0666262..21f2ddde2a 100644 --- a/tests/test_asyncio/test_cwe_404.py +++ b/tests/test_asyncio/test_cwe_404.py @@ -128,7 +128,6 @@ async def op(r): assert await r.get("foo") == b"foo" -@pytest.mark.xfail(reason="cancel does not cause disconnect") @pytest.mark.onlynoncluster @pytest.mark.parametrize("delay", argvalues=[0.05, 0.5, 1, 2]) async def test_standalone_pipeline(delay, master_host): diff --git a/tests/test_commands.py b/tests/test_commands.py index 4020f5e8a8..cb8966907f 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -1,9 +1,12 @@ import binascii import datetime import re +import threading import time +from asyncio import CancelledError from string import ascii_letters from unittest import mock +from unittest.mock import patch import pytest @@ -4743,6 +4746,38 @@ def test_psync(self, r): res = r2.psync(r2.client_id(), 1) assert b"FULLRESYNC" in res + @pytest.mark.onlynoncluster + def test_interrupted_command(self, r: redis.Redis): + """ + Regression test for issue #1128: An Un-handled BaseException + will leave the socket with un-read response to a previous + command. + """ + + ok = False + + def helper(): + with pytest.raises(CancelledError): + # blocking pop + with patch.object( + r.connection._parser, "read_response", side_effect=CancelledError + ): + r.brpop(["nonexist"]) + # if all is well, we can continue. + r.set("status", "down") # should not hang + nonlocal ok + ok = True + + thread = threading.Thread(target=helper) + thread.start() + thread.join(0.1) + try: + assert not thread.is_alive() + assert ok + finally: + # disconnect here so that fixture cleanup can proceed + r.connection.disconnect() + @pytest.mark.onlynoncluster class TestBinarySave: diff --git a/tests/test_connection.py b/tests/test_connection.py index 25b4118b2c..75ba738047 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -160,7 +160,7 @@ def test_connection_parse_response_resume(r: redis.Redis, parser_class): conn._parser._sock = mock_socket for i in range(100): try: - response = conn.read_response() + response = conn.read_response(disconnect_on_error=False) break except MockSocket.TestError: pass