From 7dccb43fed9f39ac9aa0b0686a5398d5f9009afa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Thu, 15 Dec 2022 14:54:59 +0000 Subject: [PATCH] Make disconnect_on_error optional for Connection.read_response() --- redis/asyncio/client.py | 4 +++- redis/asyncio/connection.py | 21 ++++++++++++--------- redis/client.py | 2 +- redis/connection.py | 16 +++++++++++----- 4 files changed, 27 insertions(+), 16 deletions(-) diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index abe7d67463..cae966316e 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -781,7 +781,9 @@ async def parse_response(self, block: bool = True, timeout: float = 0): await conn.connect() read_timeout = None if block else timeout - response = await self._execute(conn, conn.read_response, timeout=read_timeout) + response = await self._execute( + conn, conn.read_response, timeout=read_timeout, disconnect_on_error=False + ) if conn.health_check_interval and response == self.health_check_response: # ignore the health check message as user might not expect it diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index 4f19153318..49d08a1c68 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -787,6 +787,7 @@ async def read_response( self, disable_decoding: bool = False, timeout: Optional[float] = None, + disconnect_on_error: bool = True, ): """Read the response from a previously sent command""" read_timeout = timeout if timeout is not None else self.socket_timeout @@ -802,22 +803,24 @@ async def read_response( ) except asyncio.TimeoutError: if timeout is not None: - # user requested timeout, return None + # user requested timeout, return None. Operation can be retried return None # it was a self.socket_timeout error. - await self.disconnect(nowait=True) + if disconnect_on_error: + await self.disconnect(nowait=True) raise TimeoutError(f"Timeout reading from {self.host}:{self.port}") except OSError as e: - await self.disconnect(nowait=True) + if disconnect_on_error: + await self.disconnect(nowait=True) raise ConnectionError( f"Error while reading from {self.host}:{self.port} : {e.args}" ) - except asyncio.CancelledError: - # need this check for 3.7, where CancelledError - # is subclass of Exception, not BaseException - raise - except Exception: - await self.disconnect(nowait=True) + except BaseException: + # Also by default close in case of BaseException. A lot of code + # relies on this behaviour when doing Command/Response pairs. + # See #1128. + if disconnect_on_error: + await self.disconnect(nowait=True) raise if self.health_check_interval: diff --git a/redis/client.py b/redis/client.py index ed857c8fba..10dd9e596d 100755 --- a/redis/client.py +++ b/redis/client.py @@ -1526,7 +1526,7 @@ def try_read(): return None else: conn.connect() - return conn.read_response() + return conn.read_response(disconnect_on_error=False) response = self._execute(conn, try_read) diff --git a/redis/connection.py b/redis/connection.py index 9c5b536f89..bc4802db8e 100755 --- a/redis/connection.py +++ b/redis/connection.py @@ -801,7 +801,7 @@ def can_read(self, timeout=0): f"Error while reading from {self.host}:{self.port}: {e.args}" ) - def read_response(self, disable_decoding=False): + def read_response(self, disable_decoding=False, disconnect_on_error: bool = True): """Read the response from a previously sent command""" try: hosterr = f"{self.host}:{self.port}" @@ -811,13 +811,19 @@ def read_response(self, disable_decoding=False): try: response = self._parser.read_response(disable_decoding=disable_decoding) except socket.timeout: - self.disconnect() + if disconnect_on_error: + self.disconnect() raise TimeoutError(f"Timeout reading from {hosterr}") except OSError as e: - self.disconnect() + if disconnect_on_error: + self.disconnect() raise ConnectionError(f"Error while reading from {hosterr}" f" : {e.args}") - except Exception: - self.disconnect() + except BaseException: + # Also by default close in case of BaseException. A lot of code + # relies on this behaviour when doing Command/Response pairs. + # See #1128. + if disconnect_on_error: + self.disconnect() raise if self.health_check_interval: