Skip to content

Commit

Permalink
Make PythonParser resumable (#2510)
Browse files Browse the repository at this point in the history
* PythonParser is now resumable if _stream IO is interrupted

* Add test for parse resumability

* Clear PythonParser state when connection or parsing errors occur.

* disable test for cluster mode.

* Perform "closed" check in a single place.

* Update tests

* Simplify code.

* Remove redundant test, EOF is detected inside _readline()

* Make synchronous PythonParser restartable on error, same as HiredisParser

Fix sync PythonParser

* Add CHANGES

* isort

* Move MockStream and MockSocket into their own files
  • Loading branch information
kristjanvalur committed Jan 5, 2023
1 parent a947728 commit a9ef0c5
Show file tree
Hide file tree
Showing 7 changed files with 269 additions and 42 deletions.
1 change: 1 addition & 0 deletions CHANGES
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
* Make PythonParser resumable in case of error (#2510)
* Add `timeout=None` in `SentinelConnectionManager.read_response`
* Documentation fix: password protected socket connection (#2374)
* Allow `timeout=None` in `PubSub.get_message()` to wait forever
Expand Down
69 changes: 51 additions & 18 deletions redis/asyncio/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,11 +208,18 @@ async def read_response(
class PythonParser(BaseParser):
"""Plain Python parsing class"""

__slots__ = BaseParser.__slots__ + ("encoder",)
__slots__ = BaseParser.__slots__ + ("encoder", "_buffer", "_pos", "_chunks")

def __init__(self, socket_read_size: int):
super().__init__(socket_read_size)
self.encoder: Optional[Encoder] = None
self._buffer = b""
self._chunks = []
self._pos = 0

def _clear(self):
self._buffer = b""
self._chunks.clear()

def on_connect(self, connection: "Connection"):
"""Called when the stream connects"""
Expand All @@ -227,8 +234,11 @@ def on_disconnect(self):
if self._stream is not None:
self._stream = None
self.encoder = None
self._clear()

async def can_read_destructive(self) -> bool:
if self._buffer:
return True
if self._stream is None:
raise RedisError("Buffer is closed.")
try:
Expand All @@ -237,14 +247,23 @@ async def can_read_destructive(self) -> bool:
except asyncio.TimeoutError:
return False

async def read_response(
async def read_response(self, disable_decoding: bool = False):
if self._chunks:
# augment parsing buffer with previously read data
self._buffer += b"".join(self._chunks)
self._chunks.clear()
self._pos = 0
response = await self._read_response(disable_decoding=disable_decoding)
# Successfully parsing a response allows us to clear our parsing buffer
self._clear()
return response

async def _read_response(
self, disable_decoding: bool = False
) -> Union[EncodableT, ResponseError, None]:
if not self._stream or not self.encoder:
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
raw = await self._readline()
if not raw:
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
response: Any
byte, response = raw[:1], raw[1:]

Expand All @@ -258,6 +277,7 @@ async def read_response(
# if the error is a ConnectionError, raise immediately so the user
# is notified
if isinstance(error, ConnectionError):
self._clear() # Successful parse
raise error
# otherwise, we're dealing with a ResponseError that might belong
# inside a pipeline response. the connection's read_response()
Expand All @@ -282,7 +302,7 @@ async def read_response(
if length == -1:
return None
response = [
(await self.read_response(disable_decoding)) for _ in range(length)
(await self._read_response(disable_decoding)) for _ in range(length)
]
if isinstance(response, bytes) and disable_decoding is False:
response = self.encoder.decode(response)
Expand All @@ -293,25 +313,38 @@ async def _read(self, length: int) -> bytes:
Read `length` bytes of data. These are assumed to be followed
by a '\r\n' terminator which is subsequently discarded.
"""
if self._stream is None:
raise RedisError("Buffer is closed.")
try:
data = await self._stream.readexactly(length + 2)
except asyncio.IncompleteReadError as error:
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from error
return data[:-2]
want = length + 2
end = self._pos + want
if len(self._buffer) >= end:
result = self._buffer[self._pos : end - 2]
else:
tail = self._buffer[self._pos :]
try:
data = await self._stream.readexactly(want - len(tail))
except asyncio.IncompleteReadError as error:
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from error
result = (tail + data)[:-2]
self._chunks.append(data)
self._pos += want
return result

async def _readline(self) -> bytes:
"""
read an unknown number of bytes up to the next '\r\n'
line separator, which is discarded.
"""
if self._stream is None:
raise RedisError("Buffer is closed.")
data = await self._stream.readline()
if not data.endswith(b"\r\n"):
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
return data[:-2]
found = self._buffer.find(b"\r\n", self._pos)
if found >= 0:
result = self._buffer[self._pos : found]
else:
tail = self._buffer[self._pos :]
data = await self._stream.readline()
if not data.endswith(b"\r\n"):
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
result = (tail + data)[:-2]
self._chunks.append(data)
self._pos += len(result) + 2
return result


class HiredisParser(BaseParser):
Expand Down
58 changes: 42 additions & 16 deletions redis/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,12 +232,6 @@ def read(self, length):
self._buffer.seek(self.bytes_read)
data = self._buffer.read(length)
self.bytes_read += len(data)

# purge the buffer when we've consumed it all so it doesn't
# grow forever
if self.bytes_read == self.bytes_written:
self.purge()

return data[:-2]

def readline(self):
Expand All @@ -251,23 +245,44 @@ def readline(self):
data = buf.readline()

self.bytes_read += len(data)
return data[:-2]

# purge the buffer when we've consumed it all so it doesn't
# grow forever
if self.bytes_read == self.bytes_written:
self.purge()
def get_pos(self):
"""
Get current read position
"""
return self.bytes_read

return data[:-2]
def rewind(self, pos):
"""
Rewind the buffer to a specific position, to re-start reading
"""
self.bytes_read = pos

def purge(self):
self._buffer.seek(0)
self._buffer.truncate()
self.bytes_written = 0
"""
After a successful read, purge the read part of buffer
"""
unread = self.bytes_written - self.bytes_read

# Only if we have read all of the buffer do we truncate, to
# reduce the amount of memory thrashing. This heuristic
# can be changed or removed later.
if unread > 0:
return

if unread > 0:
# move unread data to the front
view = self._buffer.getbuffer()
view[:unread] = view[-unread:]
self._buffer.truncate(unread)
self.bytes_written = unread
self.bytes_read = 0
self._buffer.seek(0)

def close(self):
try:
self.purge()
self.bytes_written = self.bytes_read = 0
self._buffer.close()
except Exception:
# issue #633 suggests the purge/close somehow raised a
Expand Down Expand Up @@ -315,6 +330,17 @@ def can_read(self, timeout):
return self._buffer and self._buffer.can_read(timeout)

def read_response(self, disable_decoding=False):
pos = self._buffer.get_pos()
try:
result = self._read_response(disable_decoding=disable_decoding)
except BaseException:
self._buffer.rewind(pos)
raise
else:
self._buffer.purge()
return result

def _read_response(self, disable_decoding=False):
raw = self._buffer.readline()
if not raw:
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
Expand Down Expand Up @@ -355,7 +381,7 @@ def read_response(self, disable_decoding=False):
if length == -1:
return None
response = [
self.read_response(disable_decoding=disable_decoding)
self._read_response(disable_decoding=disable_decoding)
for i in range(length)
]
if isinstance(response, bytes) and disable_decoding is False:
Expand Down
41 changes: 41 additions & 0 deletions tests/mocks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Various mocks for testing


class MockSocket:
    """
    Stand-in for a readable socket that serves bytes from a fixed payload.

    Every call to ``recv`` / ``recv_into`` hands out at most 5 bytes so that
    callers are forced to deal with short reads.  When ``interrupt_every`` is
    non-zero, every ``interrupt_every``-th call raises ``TestError`` *before*
    any data is consumed, so the same bytes are served by the next call.
    """

    class TestError(BaseException):
        # Derives from BaseException, so ``except Exception`` does not catch it.
        pass

    def __init__(self, data, interrupt_every=0):
        self.data = data
        self.counter = 0  # number of recv/recv_into calls made so far
        self.pos = 0  # offset of the next unread byte in ``data``
        self.interrupt_every = interrupt_every

    def tick(self):
        """Count one read call and raise ``TestError`` when it is due."""
        self.counter += 1
        period = self.interrupt_every
        if period and (self.counter % period) == 0:
            raise self.TestError()

    def _take(self, limit):
        """Consume and return the next ``min(5, limit)`` bytes of the payload."""
        start = self.pos
        chunk = self.data[start : start + min(5, limit)]  # truncate the read size
        self.pos = start + len(chunk)
        return chunk

    def recv(self, bufsize):
        """Return up to ``bufsize`` bytes (never more than 5)."""
        self.tick()
        return self._take(bufsize)

    def recv_into(self, buffer, nbytes=0, flags=0):
        """
        Copy up to ``nbytes`` bytes (never more than 5) into ``buffer`` and
        return the number of bytes written.  ``nbytes == 0`` means "as much
        as ``buffer`` can hold".  ``flags`` is accepted but ignored.
        """
        self.tick()
        limit = len(buffer) if nbytes == 0 else nbytes
        chunk = self._take(limit)
        size = len(chunk)
        buffer[:size] = chunk
        return size
51 changes: 51 additions & 0 deletions tests/test_asyncio/mocks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import asyncio

# Helper Mocking classes for the tests.


class MockStream:
    """
    A class simulating an asyncio input buffer that serves bytes from a
    fixed payload.

    When ``interrupt_every`` is non-zero, every ``interrupt_every``-th call
    to ``read`` / ``readline`` / ``readexactly`` raises ``TestError`` before
    any data is consumed, so the same bytes are served by the next call.
    """

    class TestError(BaseException):
        # Derives from BaseException, so ``except Exception`` does not catch it.
        pass

    def __init__(self, data, interrupt_every=0):
        self.data = data
        self.counter = 0  # number of read calls made so far
        self.pos = 0  # offset of the next unread byte in ``data``
        self.interrupt_every = interrupt_every

    def tick(self):
        """Count one read call and raise ``TestError`` when it is due."""
        self.counter += 1
        if not self.interrupt_every:
            return
        if (self.counter % self.interrupt_every) == 0:
            raise self.TestError()

    async def read(self, want=-1):
        """
        Return up to ``want`` bytes, never more than 5 per call so that
        callers see short reads.  A negative ``want`` means "no limit
        requested" and is still capped at 5 bytes.
        """
        self.tick()
        # Truncate the read size, but never hand back more than was asked
        # for (the same rule MockSocket applies with ``min(5, bufsize)``).
        want = 5 if want < 0 else min(5, want)
        result = self.data[self.pos : self.pos + want]
        self.pos += len(result)
        return result

    async def readline(self):
        """
        Return bytes up to and including the next ``b"\\n"``; if there is no
        newline left, return whatever remains (possibly ``b""``).
        """
        self.tick()
        find = self.data.find(b"\n", self.pos)
        if find >= 0:
            result = self.data[self.pos : find + 1]
        else:
            result = self.data[self.pos :]
        self.pos += len(result)
        return result

    async def readexactly(self, length):
        """
        Return exactly ``length`` bytes.

        Raises ``asyncio.IncompleteReadError`` (carrying the partial data)
        when fewer bytes remain; in that case nothing is consumed.
        """
        self.tick()
        result = self.data[self.pos : self.pos + length]
        if len(result) < length:
            raise asyncio.IncompleteReadError(result, None)
        self.pos += len(result)
        return result
48 changes: 41 additions & 7 deletions tests/test_asyncio/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@

import pytest

import redis
from redis.asyncio.connection import (
BaseParser,
Connection,
PythonParser,
UnixDomainSocketConnection,
Expand All @@ -16,23 +18,27 @@
from tests.conftest import skip_if_server_version_lt

from .compat import mock
from .mocks import MockStream


@pytest.mark.onlynoncluster
async def test_invalid_response(create_redis):
r = await create_redis(single_connection_client=True)

raw = b"x"
fake_stream = MockStream(raw + b"\r\n")

parser: "PythonParser" = r.connection._parser
if not isinstance(parser, PythonParser):
pytest.skip("PythonParser only")
stream_mock = mock.Mock(parser._stream)
stream_mock.readline.return_value = raw + b"\r\n"
with mock.patch.object(parser, "_stream", stream_mock):
parser: BaseParser = r.connection._parser
with mock.patch.object(parser, "_stream", fake_stream):
with pytest.raises(InvalidResponse) as cm:
await parser.read_response()
assert str(cm.value) == f"Protocol Error: {raw!r}"
if isinstance(parser, PythonParser):
assert str(cm.value) == f"Protocol Error: {raw!r}"
else:
assert (
str(cm.value) == f'Protocol error, got "{raw.decode()}" as reply type byte'
)
await r.connection.disconnect()


@skip_if_server_version_lt("4.0.0")
Expand Down Expand Up @@ -112,3 +118,31 @@ async def test_connect_timeout_error_without_retry():
await conn.connect()
assert conn._connect.call_count == 1
assert str(e.value) == "Timeout connecting to server"


@pytest.mark.onlynoncluster
async def test_connection_parse_response_resume(r: redis.Redis):
    """
    Check that the Connection's parser (PythonParser or HiredisParser)
    survives being interrupted in the middle of stream IO: after the
    interruption, calling read_response() again must pick up where it
    left off and eventually yield the complete reply.
    """
    conn = Connection(**r.connection_pool.connection_kwargs)
    await conn.connect()
    # A 3-element multi-bulk reply whose last bulk string (25 bytes) itself
    # contains "\r\n" sequences and a "+" marker.
    message = (
        b"*3\r\n$7\r\nmessage\r\n$8\r\nchannel1\r\n"
        b"$25\r\nhi\r\nthere\r\n+how\r\nare\r\nyou\r\n"
    )

    # Every second read on the mocked stream raises MockStream.TestError.
    conn._parser._stream = MockStream(message, interrupt_every=2)

    interruptions = 0
    while interruptions < 100:
        try:
            response = await conn.read_response()
        except MockStream.TestError:
            interruptions += 1
            continue
        break
    else:
        pytest.fail("didn't receive a response")

    assert response
    # the read must have been interrupted at least once before succeeding
    assert interruptions > 0

0 comments on commit a9ef0c5

Please sign in to comment.