Skip to content

Commit

Permalink
Allow data to drain from PythonParser after connection close.
Browse files Browse the repository at this point in the history
  • Loading branch information
kristjanvalur committed Feb 8, 2023
1 parent 5cb5712 commit c769300
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 15 deletions.
24 changes: 11 additions & 13 deletions redis/asyncio/connection.py
Expand Up @@ -141,7 +141,7 @@ def decode(self, value: EncodableT, force=False) -> EncodableT:
class BaseParser:
"""Plain Python parsing class"""

__slots__ = "_stream", "_read_size"
__slots__ = "_stream", "_read_size", "_connected"

EXCEPTION_CLASSES: ExceptionMappingT = {
"ERR": {
Expand Down Expand Up @@ -172,6 +172,7 @@ class BaseParser:
def __init__(self, socket_read_size: int):
self._stream: Optional[asyncio.StreamReader] = None
self._read_size = socket_read_size
self._connected = False

def __del__(self):
try:
Expand Down Expand Up @@ -208,7 +209,7 @@ async def read_response(
class PythonParser(BaseParser):
"""Plain Python parsing class"""

__slots__ = BaseParser.__slots__ + ("encoder", "_buffer", "_pos", "_chunks")
__slots__ = ("encoder", "_buffer", "_pos", "_chunks")

def __init__(self, socket_read_size: int):
super().__init__(socket_read_size)
Expand All @@ -226,28 +227,28 @@ def on_connect(self, connection: "Connection"):
self._stream = connection._reader
if self._stream is None:
raise RedisError("Buffer is closed.")

self.encoder = connection.encoder
self._clear()
self._connected = True

def on_disconnect(self):
"""Called when the stream disconnects"""
if self._stream is not None:
self._stream = None
self.encoder = None
self._clear()
self._connected = False

async def can_read_destructive(self) -> bool:
if not self._connected:
raise RedisError("Buffer is closed.")
if self._buffer:
return True
if self._stream is None:
raise RedisError("Buffer is closed.")
try:
async with async_timeout.timeout(0):
return await self._stream.read(1)
except asyncio.TimeoutError:
return False

async def read_response(self, disable_decoding: bool = False):
if not self._connected:
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
if self._chunks:
# augment parsing buffer with previously read data
self._buffer += b"".join(self._chunks)
Expand All @@ -261,8 +262,6 @@ async def read_response(self, disable_decoding: bool = False):
async def _read_response(
self, disable_decoding: bool = False
) -> Union[EncodableT, ResponseError, None]:
if not self._stream or not self.encoder:
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
raw = await self._readline()
response: Any
byte, response = raw[:1], raw[1:]
Expand Down Expand Up @@ -350,14 +349,13 @@ async def _readline(self) -> bytes:
class HiredisParser(BaseParser):
"""Parser class for connections using Hiredis"""

__slots__ = BaseParser.__slots__ + ("_reader", "_connected")
__slots__ = ("_reader",)

def __init__(self, socket_read_size: int):
if not HIREDIS_AVAILABLE:
raise RedisError("Hiredis is not available.")
super().__init__(socket_read_size=socket_read_size)
self._reader: Optional[hiredis.Reader] = None
self._connected: bool = False

def on_connect(self, connection: "Connection"):
self._stream = connection._reader
Expand Down
2 changes: 0 additions & 2 deletions tests/test_asyncio/test_connection.py
Expand Up @@ -211,8 +211,6 @@ async def test_connection_disconect_race(parser_class):
This test verifies that a read in progress can finish even
if the `disconnect()` method is called.
"""
if parser_class == PythonParser:
pytest.xfail("doesn't work yet with PythonParser")
if parser_class == HiredisParser and not HIREDIS_AVAILABLE:
pytest.skip("Hiredis not available")

Expand Down

0 comments on commit c769300

Please sign in to comment.