Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make PythonParser resumable #2510

Merged
1 change: 1 addition & 0 deletions CHANGES
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
* Make PythonParser resumable in case of error (#2510)
* Add `timeout=None` in `SentinelConnectionManager.read_response`
* Documentation fix: password protected socket connection (#2374)
* Allow `timeout=None` in `PubSub.get_message()` to wait forever
Expand Down
69 changes: 51 additions & 18 deletions redis/asyncio/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,11 +208,18 @@ async def read_response(
class PythonParser(BaseParser):
"""Plain Python parsing class"""

__slots__ = BaseParser.__slots__ + ("encoder",)
__slots__ = BaseParser.__slots__ + ("encoder", "_buffer", "_pos", "_chunks")
chayim marked this conversation as resolved.
Show resolved Hide resolved

def __init__(self, socket_read_size: int):
    """Create a parser with an empty, resumable parsing state."""
    super().__init__(socket_read_size)
    # The encoder is attached later, when the connection is established.
    self.encoder: Optional[Encoder] = None
    # Resumable-parse state: consolidated bytes already seen, the read
    # offset into them, and raw chunks read since the last full parse.
    self._pos = 0
    self._chunks = []
    self._buffer = b""

def _clear(self):
self._buffer = b""
self._chunks.clear()

def on_connect(self, connection: "Connection"):
"""Called when the stream connects"""
Expand All @@ -227,8 +234,11 @@ def on_disconnect(self):
if self._stream is not None:
self._stream = None
self.encoder = None
self._clear()

async def can_read_destructive(self) -> bool:
if self._buffer:
return True
if self._stream is None:
raise RedisError("Buffer is closed.")
try:
Expand All @@ -237,14 +247,23 @@ async def can_read_destructive(self) -> bool:
except asyncio.TimeoutError:
return False

async def read_response(
async def read_response(self, disable_decoding: bool = False):
    """
    Parse one complete reply, resuming from any previously interrupted
    attempt by replaying the chunks it had already read.
    """
    if self._chunks:
        # A prior attempt was interrupted mid-parse: fold everything it
        # read into the parse buffer and restart parsing from the top.
        self._buffer = self._buffer + b"".join(self._chunks)
        del self._chunks[:]
        self._pos = 0
    result = await self._read_response(disable_decoding=disable_decoding)
    # The reply parsed cleanly, so the retained data can be dropped.
    self._clear()
    return result

async def _read_response(
self, disable_decoding: bool = False
) -> Union[EncodableT, ResponseError, None]:
if not self._stream or not self.encoder:
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
raw = await self._readline()
if not raw:
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
response: Any
byte, response = raw[:1], raw[1:]

Expand All @@ -258,6 +277,7 @@ async def read_response(
# if the error is a ConnectionError, raise immediately so the user
# is notified
if isinstance(error, ConnectionError):
self._clear() # Successful parse
raise error
# otherwise, we're dealing with a ResponseError that might belong
# inside a pipeline response. the connection's read_response()
Expand All @@ -282,7 +302,7 @@ async def read_response(
if length == -1:
return None
response = [
(await self.read_response(disable_decoding)) for _ in range(length)
(await self._read_response(disable_decoding)) for _ in range(length)
]
if isinstance(response, bytes) and disable_decoding is False:
response = self.encoder.decode(response)
Expand All @@ -293,25 +313,38 @@ async def _read(self, length: int) -> bytes:
Read `length` bytes of data. These are assumed to be followed
by a '\r\n' terminator which is subsequently discarded.
"""
if self._stream is None:
raise RedisError("Buffer is closed.")
try:
data = await self._stream.readexactly(length + 2)
except asyncio.IncompleteReadError as error:
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from error
return data[:-2]
want = length + 2
end = self._pos + want
if len(self._buffer) >= end:
result = self._buffer[self._pos : end - 2]
else:
tail = self._buffer[self._pos :]
try:
data = await self._stream.readexactly(want - len(tail))
except asyncio.IncompleteReadError as error:
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from error
result = (tail + data)[:-2]
self._chunks.append(data)
self._pos += want
return result

async def _readline(self) -> bytes:
"""
read an unknown number of bytes up to the next '\r\n'
line separator, which is discarded.
"""
if self._stream is None:
raise RedisError("Buffer is closed.")
data = await self._stream.readline()
if not data.endswith(b"\r\n"):
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
return data[:-2]
found = self._buffer.find(b"\r\n", self._pos)
if found >= 0:
result = self._buffer[self._pos : found]
else:
tail = self._buffer[self._pos :]
data = await self._stream.readline()
if not data.endswith(b"\r\n"):
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
result = (tail + data)[:-2]
self._chunks.append(data)
self._pos += len(result) + 2
return result


class HiredisParser(BaseParser):
Expand Down
58 changes: 42 additions & 16 deletions redis/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,12 +232,6 @@ def read(self, length):
self._buffer.seek(self.bytes_read)
data = self._buffer.read(length)
self.bytes_read += len(data)

# purge the buffer when we've consumed it all so it doesn't
# grow forever
if self.bytes_read == self.bytes_written:
self.purge()

return data[:-2]

def readline(self):
Expand All @@ -251,23 +245,44 @@ def readline(self):
data = buf.readline()

self.bytes_read += len(data)
return data[:-2]

# purge the buffer when we've consumed it all so it doesn't
# grow forever
if self.bytes_read == self.bytes_written:
self.purge()
def get_pos(self):
    """Return the current read offset within the buffer."""
    return self.bytes_read

return data[:-2]
def rewind(self, pos):
    """Reset the read offset to ``pos`` so reading can restart there."""
    self.bytes_read = pos

def purge(self):
    """
    After a successful read, purge the consumed part of the buffer.

    Only acts when everything written has also been read; a partially
    consumed buffer is left untouched so an interrupted parse can be
    resumed from its rewind position.
    """
    unread = self.bytes_written - self.bytes_read

    # Only truncate once the whole buffer has been consumed, to reduce
    # the amount of memory thrashing.  This heuristic can be changed
    # or removed later.
    if unread > 0:
        return

    # Fully consumed: reset to an empty buffer.  (The diff's second
    # `if unread > 0:` branch was unreachable after the guard above
    # and has been removed; behavior is unchanged.)
    self._buffer.seek(0)
    self._buffer.truncate()
    self.bytes_written = 0
    self.bytes_read = 0

def close(self):
try:
self.purge()
self.bytes_written = self.bytes_read = 0
self._buffer.close()
except Exception:
# issue #633 suggests the purge/close somehow raised a
Expand Down Expand Up @@ -315,6 +330,17 @@ def can_read(self, timeout):
return self._buffer and self._buffer.can_read(timeout)

def read_response(self, disable_decoding=False):
    """
    Read a full response, rewinding the buffer on any failure so a
    later call can retry the parse from the same starting point.
    """
    start = self._buffer.get_pos()
    try:
        response = self._read_response(disable_decoding=disable_decoding)
    except BaseException:
        # Interrupted mid-parse (including KeyboardInterrupt/timeouts):
        # rewind so the next attempt re-reads the same bytes.
        self._buffer.rewind(start)
        raise
    # Parse succeeded; the consumed bytes can now be discarded.
    self._buffer.purge()
    return response

def _read_response(self, disable_decoding=False):
raw = self._buffer.readline()
if not raw:
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
Expand Down Expand Up @@ -355,7 +381,7 @@ def read_response(self, disable_decoding=False):
if length == -1:
return None
response = [
self.read_response(disable_decoding=disable_decoding)
self._read_response(disable_decoding=disable_decoding)
for i in range(length)
]
if isinstance(response, bytes) and disable_decoding is False:
Expand Down
41 changes: 41 additions & 0 deletions tests/mocks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Various mocks for testing


class MockSocket:
    """
    Simulates a readable socket that serves data in small increments,
    optionally raising a special exception on every Nth read.
    """

    class TestError(BaseException):
        pass

    def __init__(self, data, interrupt_every=0):
        self.data = data
        self.counter = 0
        self.pos = 0
        self.interrupt_every = interrupt_every

    def tick(self):
        # Count this read and raise on every `interrupt_every`-th call.
        self.counter += 1
        if self.interrupt_every and self.counter % self.interrupt_every == 0:
            raise self.TestError()

    def recv(self, bufsize):
        self.tick()
        # Serve at most 5 bytes per call to force short reads.
        chunk = self.data[self.pos : self.pos + min(5, bufsize)]
        self.pos += len(chunk)
        return chunk

    def recv_into(self, buffer, nbytes=0, flags=0):
        self.tick()
        # nbytes == 0 means "fill the buffer"; still capped at 5 bytes.
        limit = min(5, nbytes or len(buffer))
        chunk = self.data[self.pos : self.pos + limit]
        self.pos += len(chunk)
        buffer[: len(chunk)] = chunk
        return len(chunk)
51 changes: 51 additions & 0 deletions tests/test_asyncio/mocks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import asyncio

# Helper Mocking classes for the tests.


class MockStream:
    """
    Simulates an asyncio input stream that serves data in small pieces,
    optionally raising a special exception on every Nth read.
    """

    class TestError(BaseException):
        pass

    def __init__(self, data, interrupt_every=0):
        self.data = data
        self.counter = 0
        self.pos = 0
        self.interrupt_every = interrupt_every

    def tick(self):
        # Count this read and raise on every `interrupt_every`-th call.
        self.counter += 1
        if self.interrupt_every and self.counter % self.interrupt_every == 0:
            raise self.TestError()

    async def read(self, want):
        self.tick()
        # Ignore the requested size: serve up to 5 bytes to force
        # short reads.
        chunk = self.data[self.pos : self.pos + 5]
        self.pos += len(chunk)
        return chunk

    async def readline(self):
        self.tick()
        newline = self.data.find(b"\n", self.pos)
        end = newline + 1 if newline >= 0 else len(self.data)
        chunk = self.data[self.pos : end]
        self.pos += len(chunk)
        return chunk

    async def readexactly(self, length):
        self.tick()
        chunk = self.data[self.pos : self.pos + length]
        if len(chunk) < length:
            raise asyncio.IncompleteReadError(chunk, None)
        self.pos += len(chunk)
        return chunk
48 changes: 41 additions & 7 deletions tests/test_asyncio/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@

import pytest

import redis
from redis.asyncio.connection import (
BaseParser,
Connection,
PythonParser,
UnixDomainSocketConnection,
Expand All @@ -16,23 +18,27 @@
from tests.conftest import skip_if_server_version_lt

from .compat import mock
from .mocks import MockStream


@pytest.mark.onlynoncluster
async def test_invalid_response(create_redis):
    """
    An unknown reply type byte raises InvalidResponse, for either parser.
    (The diff span retained the removed mock-stub lines alongside the new
    MockStream-based setup, leaving duplicate conflicting `parser = ...`
    statements; only the new version is kept here.)
    """
    r = await create_redis(single_connection_client=True)

    raw = b"x"
    fake_stream = MockStream(raw + b"\r\n")

    parser: BaseParser = r.connection._parser
    with mock.patch.object(parser, "_stream", fake_stream):
        with pytest.raises(InvalidResponse) as cm:
            await parser.read_response()
    # The two parser implementations word their protocol errors differently.
    if isinstance(parser, PythonParser):
        assert str(cm.value) == f"Protocol Error: {raw!r}"
    else:
        assert (
            str(cm.value) == f'Protocol error, got "{raw.decode()}" as reply type byte'
        )
    await r.connection.disconnect()


@skip_if_server_version_lt("4.0.0")
Expand Down Expand Up @@ -112,3 +118,31 @@ async def test_connect_timeout_error_without_retry():
await conn.connect()
assert conn._connect.call_count == 1
assert str(e.value) == "Timeout connecting to server"


@pytest.mark.onlynoncluster
async def test_connection_parse_response_resume(r: redis.Redis):
    """
    Verify that the connection parser, be that PythonParser or
    HiredisParser, can be interrupted at IO time and then resume
    parsing on a subsequent call.
    """
    conn = Connection(**r.connection_pool.connection_kwargs)
    await conn.connect()
    message = (
        b"*3\r\n$7\r\nmessage\r\n$8\r\nchannel1\r\n"
        b"$25\r\nhi\r\nthere\r\n+how\r\nare\r\nyou\r\n"
    )

    # Interrupt every other read so at least one attempt fails mid-parse.
    conn._parser._stream = MockStream(message, interrupt_every=2)
    for attempt in range(100):
        try:
            response = await conn.read_response()
        except MockStream.TestError:
            continue
        break
    else:
        pytest.fail("didn't receive a response")
    assert response
    # At least one interruption must have occurred before success.
    assert attempt > 0