Skip to content

Commit

Permalink
Move MockStream and MockSocket into their own files
Browse files Browse the repository at this point in the history
  • Loading branch information
kristjanvalur committed Jan 4, 2023
1 parent 789cc94 commit 56dc082
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 95 deletions.
41 changes: 41 additions & 0 deletions tests/mocks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Various mocks for testing


class MockSocket:
"""
A class simulating an readable socket, optionally raising a
special exception every other read.
"""

class TestError(BaseException):
pass

def __init__(self, data, interrupt_every=0):
self.data = data
self.counter = 0
self.pos = 0
self.interrupt_every = interrupt_every

def tick(self):
self.counter += 1
if not self.interrupt_every:
return
if (self.counter % self.interrupt_every) == 0:
raise self.TestError()

def recv(self, bufsize):
self.tick()
bufsize = min(5, bufsize) # truncate the read size
result = self.data[self.pos : self.pos + bufsize]
self.pos += len(result)
return result

def recv_into(self, buffer, nbytes=0, flags=0):
self.tick()
if nbytes == 0:
nbytes = len(buffer)
nbytes = min(5, nbytes) # truncate the read size
result = self.data[self.pos : self.pos + nbytes]
self.pos += len(result)
buffer[: len(result)] = result
return len(result)
51 changes: 51 additions & 0 deletions tests/test_asyncio/mocks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import asyncio

# Helper Mocking classes for the tests.


class MockStream:
"""
A class simulating an asyncio input buffer, optionally raising a
special exception every other read.
"""

class TestError(BaseException):
pass

def __init__(self, data, interrupt_every=0):
self.data = data
self.counter = 0
self.pos = 0
self.interrupt_every = interrupt_every

def tick(self):
self.counter += 1
if not self.interrupt_every:
return
if (self.counter % self.interrupt_every) == 0:
raise self.TestError()

async def read(self, want):
self.tick()
want = 5
result = self.data[self.pos : self.pos + want]
self.pos += len(result)
return result

async def readline(self):
self.tick()
find = self.data.find(b"\n", self.pos)
if find >= 0:
result = self.data[self.pos : find + 1]
else:
result = self.data[self.pos :]
self.pos += len(result)
return result

async def readexactly(self, length):
self.tick()
result = self.data[self.pos : self.pos + length]
if len(result) < length:
raise asyncio.IncompleteReadError(result, None)
self.pos += len(result)
return result
55 changes: 4 additions & 51 deletions tests/test_asyncio/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,15 @@
from tests.conftest import skip_if_server_version_lt

from .compat import mock
from .mocks import MockStream


@pytest.mark.onlynoncluster
async def test_invalid_response(create_redis):
r = await create_redis(single_connection_client=True)

raw = b"x"
fake_stream = FakeStream(raw + b"\r\n")
fake_stream = MockStream(raw + b"\r\n")

parser: BaseParser = r.connection._parser
with mock.patch.object(parser, "_stream", fake_stream):
Expand Down Expand Up @@ -119,54 +120,6 @@ async def test_connect_timeout_error_without_retry():
assert str(e.value) == "Timeout connecting to server"


class FakeStream:
"""
A class simulating an asyncio input buffer, but raising a
special exception every other read.
"""

class TestError(BaseException):
pass

def __init__(self, data, interrupt_every=0):
self.data = data
self.counter = 0
self.pos = 0
self.interrupt_every = interrupt_every

def tick(self):
self.counter += 1
if not self.interrupt_every:
return
if (self.counter % self.interrupt_every) == 0:
raise self.TestError()

async def read(self, want):
self.tick()
want = 5
result = self.data[self.pos : self.pos + want]
self.pos += len(result)
return result

async def readline(self):
self.tick()
find = self.data.find(b"\n", self.pos)
if find >= 0:
result = self.data[self.pos : find + 1]
else:
result = self.data[self.pos :]
self.pos += len(result)
return result

async def readexactly(self, length):
self.tick()
result = self.data[self.pos : self.pos + length]
if len(result) < length:
raise asyncio.IncompleteReadError(result, None)
self.pos += len(result)
return result


@pytest.mark.onlynoncluster
async def test_connection_parse_response_resume(r: redis.Redis):
"""
Expand All @@ -181,12 +134,12 @@ async def test_connection_parse_response_resume(r: redis.Redis):
b"$25\r\nhi\r\nthere\r\n+how\r\nare\r\nyou\r\n"
)

conn._parser._stream = FakeStream(message, interrupt_every=2)
conn._parser._stream = MockStream(message, interrupt_every=2)
for i in range(100):
try:
response = await conn.read_response()
break
except FakeStream.TestError:
except MockStream.TestError:
pass

else:
Expand Down
49 changes: 5 additions & 44 deletions tests/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from redis.utils import HIREDIS_AVAILABLE

from .conftest import skip_if_server_version_lt
from .mocks import MockSocket


@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only")
Expand Down Expand Up @@ -125,46 +126,6 @@ def test_connect_timeout_error_without_retry(self):
self.clear(conn)


class FakeSocket:
"""
A class simulating an readable socket, but raising a
special exception every other read.
"""

class TestError(BaseException):
pass

def __init__(self, data, interrupt_every=0):
self.data = data
self.counter = 0
self.pos = 0
self.interrupt_every = interrupt_every

def tick(self):
self.counter += 1
if not self.interrupt_every:
return
if (self.counter % self.interrupt_every) == 0:
raise self.TestError()

def recv(self, bufsize):
self.tick()
bufsize = min(5, bufsize) # truncate the read size
result = self.data[self.pos : self.pos + bufsize]
self.pos += len(result)
return result

def recv_into(self, buffer, nbytes=0, flags=0):
self.tick()
if nbytes == 0:
nbytes = len(buffer)
nbytes = min(5, nbytes) # truncate the read size
result = self.data[self.pos : self.pos + nbytes]
self.pos += len(result)
buffer[: len(result)] = result
return len(result)


@pytest.mark.onlynoncluster
@pytest.mark.parametrize(
"parser_class", [PythonParser, HiredisParser], ids=["PythonParser", "HiredisParser"]
Expand All @@ -185,17 +146,17 @@ def test_connection_parse_response_resume(r: redis.Redis, parser_class):
b"*3\r\n$7\r\nmessage\r\n$8\r\nchannel1\r\n"
b"$25\r\nhi\r\nthere\r\n+how\r\nare\r\nyou\r\n"
)
fake_socket = FakeSocket(message, interrupt_every=2)
mock_socket = MockSocket(message, interrupt_every=2)

if isinstance(conn._parser, PythonParser):
conn._parser._buffer._sock = fake_socket
conn._parser._buffer._sock = mock_socket
else:
conn._parser._sock = fake_socket
conn._parser._sock = mock_socket
for i in range(100):
try:
response = conn.read_response()
break
except FakeSocket.TestError:
except MockSocket.TestError:
pass

else:
Expand Down

0 comments on commit 56dc082

Please sign in to comment.