Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PythonRespSerializer: better chunking #3076

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
63 changes: 16 additions & 47 deletions redis/asyncio/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from redis.asyncio.retry import Retry
from redis.backoff import NoBackoff
from redis.compat import Protocol, TypedDict
from redis.connection import DEFAULT_RESP_VERSION
from redis.connection import DEFAULT_RESP_VERSION, HiredisRespSerializer, PythonRespSerializer
from redis.credentials import CredentialProvider, UsernamePasswordCredentialProvider
from redis.exceptions import (
AuthenticationError,
Expand All @@ -47,7 +47,7 @@
TimeoutError,
)
from redis.typing import EncodableT
from redis.utils import HIREDIS_AVAILABLE, get_lib_version, str_if_bytes
from redis.utils import HIREDIS_AVAILABLE, get_lib_version, str_if_bytes, HIREDIS_PACK_AVAILABLE

from .._parsers import (
BaseParser,
Expand Down Expand Up @@ -147,6 +147,7 @@ def __init__(
encoder_class: Type[Encoder] = Encoder,
credential_provider: Optional[CredentialProvider] = None,
protocol: Optional[int] = 2,
command_packer: Optional[Callable[[], None]] = None,
):
if (username or password) and credential_provider is not None:
raise DataError(
Expand Down Expand Up @@ -204,6 +205,7 @@ def __init__(
if p < 2 or p > 3:
raise ConnectionError("protocol must be either 2 or 3")
self.protocol = protocol
self._command_packer = self._construct_command_packer(command_packer)

def __del__(self, _warnings: Any = warnings):
# For some reason, the individual streams don't get properly garbage
Expand Down Expand Up @@ -235,6 +237,14 @@ def repr_pieces(self):
def is_connected(self):
return self._reader is not None and self._writer is not None

def _construct_command_packer(self, packer):
if packer is not None:
return packer
elif HIREDIS_PACK_AVAILABLE:
return HiredisRespSerializer()
else:
return PythonRespSerializer(self._buffer_cutoff, self.encoder.encode)

def register_connect_callback(self, callback):
"""
Register a callback to be called when the connection is established either
Expand Down Expand Up @@ -495,7 +505,8 @@ async def send_packed_command(
async def send_command(self, *args: Any, **kwargs: Any) -> None:
"""Pack and send a command to the Redis server"""
await self.send_packed_command(
self.pack_command(*args), check_health=kwargs.get("check_health", True)
self._command_packer.pack(*args),
check_health=kwargs.get("check_health", True),
)

async def can_read_destructive(self):
Expand Down Expand Up @@ -569,51 +580,9 @@ async def read_response(
raise response from None
return response

def pack_command(self, *args: EncodableT) -> List[bytes]:
def pack_command(self, *args):
"""Pack a series of arguments into the Redis protocol"""
output = []
# the client might have included 1 or more literal arguments in
# the command name, e.g., 'CONFIG GET'. The Redis server expects these
# arguments to be sent separately, so split the first argument
# manually. These arguments should be bytestrings so that they are
# not encoded.
assert not isinstance(args[0], float)
if isinstance(args[0], str):
args = tuple(args[0].encode().split()) + args[1:]
elif b" " in args[0]:
args = tuple(args[0].split()) + args[1:]

buff = SYM_EMPTY.join((SYM_STAR, str(len(args)).encode(), SYM_CRLF))

buffer_cutoff = self._buffer_cutoff
for arg in map(self.encoder.encode, args):
# to avoid large string mallocs, chunk the command into the
# output list if we're sending large values or memoryviews
arg_length = len(arg)
if (
len(buff) > buffer_cutoff
or arg_length > buffer_cutoff
or isinstance(arg, memoryview)
):
buff = SYM_EMPTY.join(
(buff, SYM_DOLLAR, str(arg_length).encode(), SYM_CRLF)
)
output.append(buff)
output.append(arg)
buff = SYM_CRLF
else:
buff = SYM_EMPTY.join(
(
buff,
SYM_DOLLAR,
str(arg_length).encode(),
SYM_CRLF,
arg,
SYM_CRLF,
)
)
output.append(buff)
return output
return self._command_packer.pack(*args)

def pack_commands(self, commands: Iterable[Iterable[EncodableT]]) -> List[bytes]:
"""Pack multiple commands into the Redis protocol"""
Expand Down
39 changes: 17 additions & 22 deletions redis/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,34 +92,29 @@ def pack(self, *args):

buff = SYM_EMPTY.join((SYM_STAR, str(len(args)).encode(), SYM_CRLF))

def get_chunks(_args):
for _arg in map(self.encode, _args):
_arg_length = len(_arg)
yield SYM_EMPTY.join((SYM_DOLLAR, str(_arg_length).encode(), SYM_CRLF))
if isinstance(_arg, memoryview):
yield _arg # we yield it independently to avoid copying memoryview
yield SYM_CRLF
else:
yield SYM_EMPTY.join((_arg, SYM_CRLF))

buffer_cutoff = self._buffer_cutoff
for arg in map(self.encode, args):
for arg in get_chunks(args):
# to avoid large string mallocs, chunk the command into the
# output list if we're sending large values or memoryviews
arg_length = len(arg)
if (
len(buff) > buffer_cutoff
or arg_length > buffer_cutoff
or isinstance(arg, memoryview)
):
buff = SYM_EMPTY.join(
(buff, SYM_DOLLAR, str(arg_length).encode(), SYM_CRLF)
)
# output list if we're sending large values
if len(buff) + len(arg) > buffer_cutoff or isinstance(arg, memoryview):
output.append(buff)
buff = SYM_EMPTY
if isinstance(arg, memoryview):
output.append(arg)
buff = SYM_CRLF
else:
buff = SYM_EMPTY.join(
(
buff,
SYM_DOLLAR,
str(arg_length).encode(),
SYM_CRLF,
arg,
SYM_CRLF,
)
)
buff = SYM_EMPTY.join((buff, arg))
output.append(buff)

return output


Expand Down
53 changes: 52 additions & 1 deletion tests/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import pytest
import redis
from redis import ConnectionPool, Redis
from redis._parsers import _HiredisParser, _RESP2Parser, _RESP3Parser
from redis._parsers import _HiredisParser, _RESP2Parser, _RESP3Parser, Encoder
from redis.backoff import NoBackoff
from redis.connection import (
Connection,
Expand Down Expand Up @@ -296,3 +296,54 @@ def mock_disconnect(_):

assert called == 1
pool.disconnect()


class TestPythonRespSerializer:
GET_A_ENCODED = b"*2\r\n$3\r\nGET\r\n$1\r\na\r\n"

def test_pack_buffer_cutoff_max(self):
encoder = Encoder("utf-8", "strict", False)
serializer = redis.connection.PythonRespSerializer(buffer_cutoff=1000, encode=encoder.encode)
packed = serializer.pack("GET", "a")
assert packed == [self.GET_A_ENCODED]

def test_pack_buffer_cutoff_average(self):
expected = [
b"*2\r\n$3\r\n",
b"GET\r\n",
b"$1\r\na\r\n",
]
assert b"".join(expected) == self.GET_A_ENCODED

encoder = Encoder("utf-8", "strict", False)
serializer = redis.connection.PythonRespSerializer(buffer_cutoff=8, encode=encoder.encode)
packed = serializer.pack("GET", "a")
assert packed == expected

def test_pack_buffer_cutoff_min(self):
expected = [
b"*2\r\n",
b"$3\r\n",
b"GET\r\n",
b"$1\r\n",
b"a\r\n",
]
assert b"".join(expected) == self.GET_A_ENCODED

encoder = Encoder("utf-8", "strict", False)
serializer = redis.connection.PythonRespSerializer(buffer_cutoff=1, encode=encoder.encode)
packed = serializer.pack("GET", "a")
assert packed == expected

def test_pack_memoryview(self):
expected = [
b"*2\r\n$3\r\n",
b"GET", # memoryview stays independently to avoid copying
b"\r\n$1\r\na\r\n",
]
assert b"".join(expected) == self.GET_A_ENCODED

encoder = Encoder("utf-8", "strict", False)
serializer = redis.connection.PythonRespSerializer(buffer_cutoff=1000, encode=encoder.encode)
packed = serializer.pack(memoryview(b"GET"), "a")
assert packed == expected