Skip to content

Commit

Permalink
PythonRespSerializer: better chunking
Browse files Browse the repository at this point in the history
The previous implementation did not distinguish between memoryviews, which
must be sent separately to avoid copying, and values that merely pushed the
buffer past buffer_cutoff.
As a result, some single values were sent separately for no reason,
so a chunk could be just a single encoded integer, even though
buffer_cutoff is pretty large.
  • Loading branch information
VadimPushtaev committed Dec 16, 2023
1 parent f6a4b49 commit d40c9a0
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 23 deletions.
39 changes: 17 additions & 22 deletions redis/connection.py
Expand Up @@ -92,34 +92,29 @@ def pack(self, *args):

buff = SYM_EMPTY.join((SYM_STAR, str(len(args)).encode(), SYM_CRLF))

def get_chunks(_args):
    """Yield the RESP bulk-string pieces for every argument, in wire order.

    Each argument is first passed through ``self.encode``. For a regular
    encoded value two pieces are produced: the ``$<length>\r\n`` header and
    the payload joined with its trailing CRLF. A ``memoryview`` payload is
    yielded as its own piece, followed by a separate CRLF piece, so the
    caller never has to copy the underlying buffer into a new bytes object.
    """
    for _arg in map(self.encode, _args):
        _is_view = isinstance(_arg, memoryview)
        # The bulk-string header must announce the payload size in bytes.
        # len() of a memoryview counts items along its first dimension,
        # which differs from the byte count whenever itemsize > 1, so use
        # nbytes for views.
        _arg_length = _arg.nbytes if _is_view else len(_arg)
        yield SYM_EMPTY.join((SYM_DOLLAR, str(_arg_length).encode(), SYM_CRLF))
        if _is_view:
            yield _arg  # yielded independently to avoid copying memoryview
            yield SYM_CRLF
        else:
            yield SYM_EMPTY.join((_arg, SYM_CRLF))

buffer_cutoff = self._buffer_cutoff
for arg in map(self.encode, args):
for arg in get_chunks(args):
# to avoid large string mallocs, chunk the command into the
# output list if we're sending large values or memoryviews
arg_length = len(arg)
if (
len(buff) > buffer_cutoff
or arg_length > buffer_cutoff
or isinstance(arg, memoryview)
):
buff = SYM_EMPTY.join(
(buff, SYM_DOLLAR, str(arg_length).encode(), SYM_CRLF)
)
# output list if we're sending large values
if len(buff) + len(arg) > buffer_cutoff or isinstance(arg, memoryview):
output.append(buff)
buff = SYM_EMPTY
if isinstance(arg, memoryview):
output.append(arg)
buff = SYM_CRLF
else:
buff = SYM_EMPTY.join(
(
buff,
SYM_DOLLAR,
str(arg_length).encode(),
SYM_CRLF,
arg,
SYM_CRLF,
)
)
buff = SYM_EMPTY.join((buff, arg))
output.append(buff)

return output


Expand Down
53 changes: 52 additions & 1 deletion tests/test_connection.py
Expand Up @@ -6,7 +6,7 @@
import pytest
import redis
from redis import ConnectionPool, Redis
from redis._parsers import _HiredisParser, _RESP2Parser, _RESP3Parser
from redis._parsers import _HiredisParser, _RESP2Parser, _RESP3Parser, Encoder
from redis.backoff import NoBackoff
from redis.connection import (
Connection,
Expand Down Expand Up @@ -296,3 +296,54 @@ def mock_disconnect(_):

assert called == 1
pool.disconnect()


class TestPythonRespSerializer:
    """Check how ``PythonRespSerializer.pack`` splits ``GET a`` into chunks."""

    # Full wire representation of the command ``GET a``.
    GET_A_ENCODED = b"*2\r\n$3\r\nGET\r\n$1\r\na\r\n"

    def test_pack_buffer_cutoff_max(self):
        """A cutoff larger than the whole command yields one single chunk."""
        utf8_encoder = Encoder("utf-8", "strict", False)
        resp_serializer = redis.connection.PythonRespSerializer(
            buffer_cutoff=1000, encode=utf8_encoder.encode
        )
        assert resp_serializer.pack("GET", "a") == [self.GET_A_ENCODED]

    def test_pack_buffer_cutoff_average(self):
        """A cutoff of 8 groups neighbouring pieces without exceeding it."""
        chunks = [b"*2\r\n$3\r\n", b"GET\r\n", b"$1\r\na\r\n"]
        # Sanity check: the expected chunks really spell out the command.
        assert b"".join(chunks) == self.GET_A_ENCODED

        utf8_encoder = Encoder("utf-8", "strict", False)
        resp_serializer = redis.connection.PythonRespSerializer(
            buffer_cutoff=8, encode=utf8_encoder.encode
        )
        assert resp_serializer.pack("GET", "a") == chunks

    def test_pack_buffer_cutoff_min(self):
        """A cutoff of 1 leaves every piece in a chunk of its own."""
        chunks = [b"*2\r\n", b"$3\r\n", b"GET\r\n", b"$1\r\n", b"a\r\n"]
        # Sanity check: the expected chunks really spell out the command.
        assert b"".join(chunks) == self.GET_A_ENCODED

        utf8_encoder = Encoder("utf-8", "strict", False)
        resp_serializer = redis.connection.PythonRespSerializer(
            buffer_cutoff=1, encode=utf8_encoder.encode
        )
        assert resp_serializer.pack("GET", "a") == chunks

    def test_pack_memoryview(self):
        """A memoryview argument is emitted as a separate chunk."""
        chunks = [
            b"*2\r\n$3\r\n",
            b"GET",  # the memoryview is kept as its own chunk to avoid copying
            b"\r\n$1\r\na\r\n",
        ]
        # Sanity check: the expected chunks really spell out the command.
        assert b"".join(chunks) == self.GET_A_ENCODED

        utf8_encoder = Encoder("utf-8", "strict", False)
        resp_serializer = redis.connection.PythonRespSerializer(
            buffer_cutoff=1000, encode=utf8_encoder.encode
        )
        assert resp_serializer.pack(memoryview(b"GET"), "a") == chunks

0 comments on commit d40c9a0

Please sign in to comment.