Skip to content

Commit

Permalink
Merge branch 'master' into clustersupport
Browse files Browse the repository at this point in the history
  • Loading branch information
ramwin committed Mar 8, 2024
2 parents cbadd10 + 9ad1546 commit 871630f
Show file tree
Hide file tree
Showing 12 changed files with 161 additions and 56 deletions.
2 changes: 1 addition & 1 deletion redis/_parsers/base.py
Expand Up @@ -182,7 +182,7 @@ async def can_read_destructive(self) -> bool:
return True
try:
async with async_timeout(0):
return await self._stream.read(1)
return self._stream.at_eof()
except TimeoutError:
return False

Expand Down
8 changes: 6 additions & 2 deletions redis/_parsers/resp3.py
Expand Up @@ -117,7 +117,9 @@ def _read_response(self, disable_decoding=False, push_request=False):
)
for _ in range(int(response))
]
self.handle_push_response(response, disable_decoding, push_request)
response = self.handle_push_response(
response, disable_decoding, push_request
)
else:
raise InvalidResponse(f"Protocol Error: {raw!r}")

Expand Down Expand Up @@ -259,7 +261,9 @@ async def _read_response(
)
for _ in range(int(response))
]
await self.handle_push_response(response, disable_decoding, push_request)
response = await self.handle_push_response(
response, disable_decoding, push_request
)
else:
raise InvalidResponse(f"Protocol Error: {raw!r}")

Expand Down
40 changes: 21 additions & 19 deletions redis/asyncio/client.py
Expand Up @@ -629,25 +629,27 @@ async def execute_command(self, *args, **options):
pool = self.connection_pool
conn = self.connection or await pool.get_connection(command_name, **options)
response_from_cache = await conn._get_from_local_cache(args)
if response_from_cache is not None:
return response_from_cache
else:
if self.single_connection_client:
await self._single_conn_lock.acquire()
try:
response = await conn.retry.call_with_retry(
lambda: self._send_command_parse_response(
conn, command_name, *args, **options
),
lambda error: self._disconnect_raise(conn, error),
)
conn._add_to_local_cache(args, response, keys)
return response
finally:
if self.single_connection_client:
self._single_conn_lock.release()
if not self.connection:
await pool.release(conn)
try:
if response_from_cache is not None:
return response_from_cache
else:
try:
if self.single_connection_client:
await self._single_conn_lock.acquire()
response = await conn.retry.call_with_retry(
lambda: self._send_command_parse_response(
conn, command_name, *args, **options
),
lambda error: self._disconnect_raise(conn, error),
)
conn._add_to_local_cache(args, response, keys)
return response
finally:
if self.single_connection_client:
self._single_conn_lock.release()
finally:
if not self.connection:
await pool.release(conn)

async def parse_response(
self, connection: Connection, command_name: Union[str, bytes], **options
Expand Down
14 changes: 10 additions & 4 deletions redis/asyncio/connection.py
Expand Up @@ -685,7 +685,7 @@ def pack_commands(self, commands: Iterable[Iterable[EncodableT]]) -> List[bytes]

def _socket_is_empty(self):
"""Check if the socket is empty"""
return not self._reader.at_eof()
return len(self._reader._buffer) == 0

def _cache_invalidation_process(
self, data: List[Union[str, Optional[List[str]]]]
Expand Down Expand Up @@ -1192,12 +1192,18 @@ def make_connection(self):
async def ensure_connection(self, connection: AbstractConnection):
"""Ensure that the connection object is connected and valid"""
await connection.connect()
# connections that the pool provides should be ready to send
# a command. if not, the connection was either returned to the
# if client caching is not enabled connections that the pool
# provides should be ready to send a command.
# if not, the connection was either returned to the
# pool before all data has been read or the socket has been
# closed. either way, reconnect and verify everything is good.
# (if caching enabled the connection will not always be ready
# to send a command because it may contain invalidation messages)
try:
if await connection.can_read_destructive():
if (
await connection.can_read_destructive()
and connection.client_cache is None
):
raise ConnectionError("Connection has data") from None
except (ConnectionError, OSError):
await connection.disconnect()
Expand Down
14 changes: 7 additions & 7 deletions redis/client.py
Expand Up @@ -567,10 +567,10 @@ def execute_command(self, *args, **options):
pool = self.connection_pool
conn = self.connection or pool.get_connection(command_name, **options)
response_from_cache = conn._get_from_local_cache(args)
if response_from_cache is not None:
return response_from_cache
else:
try:
try:
if response_from_cache is not None:
return response_from_cache
else:
response = conn.retry.call_with_retry(
lambda: self._send_command_parse_response(
conn, command_name, *args, **options
Expand All @@ -579,9 +579,9 @@ def execute_command(self, *args, **options):
)
conn._add_to_local_cache(args, response, keys)
return response
finally:
if not self.connection:
pool.release(conn)
finally:
if not self.connection:
pool.release(conn)

def parse_response(self, connection, command_name, **options):
"""Parses a response from the Redis server"""
Expand Down
2 changes: 1 addition & 1 deletion redis/commands/core.py
Expand Up @@ -2011,7 +2011,7 @@ def mget(self, keys: KeysT, *args: EncodableT) -> ResponseT:
options = {}
if not args:
options[EMPTY_RESPONSE] = []
options["keys"] = keys
options["keys"] = args
return self.execute_command("MGET", *args, **options)

def mset(self, mapping: Mapping[AnyKeyT, EncodableT]) -> ResponseT:
Expand Down
17 changes: 7 additions & 10 deletions redis/connection.py
@@ -1,6 +1,5 @@
import copy
import os
import select
import socket
import ssl
import sys
Expand Down Expand Up @@ -609,11 +608,6 @@ def pack_commands(self, commands):
output.append(SYM_EMPTY.join(pieces))
return output

def _socket_is_empty(self):
"""Check if the socket is empty"""
r, _, _ = select.select([self._sock], [], [], 0)
return not bool(r)

def _cache_invalidation_process(
self, data: List[Union[str, Optional[List[str]]]]
) -> None:
Expand All @@ -639,7 +633,7 @@ def _get_from_local_cache(self, command: str):
or command[0] not in self.cache_whitelist
):
return None
while not self._socket_is_empty():
while self.can_read():
self.read_response(push_request=True)
return self.client_cache.get(command)

Expand Down Expand Up @@ -1189,12 +1183,15 @@ def get_connection(self, command_name: str, *keys, **options) -> "Connection":
try:
# ensure this connection is connected to Redis
connection.connect()
# connections that the pool provides should be ready to send
# a command. if not, the connection was either returned to the
# if client caching is not enabled connections that the pool
# provides should be ready to send a command.
# if not, the connection was either returned to the
# pool before all data has been read or the socket has been
# closed. either way, reconnect and verify everything is good.
# (if caching enabled the connection will not always be ready
# to send a command because it may contain invalidation messages)
try:
if connection.can_read():
if connection.can_read() and connection.client_cache is None:
raise ConnectionError("Connection has data")
except (ConnectionError, OSError):
connection.disconnect()
Expand Down
2 changes: 1 addition & 1 deletion redis/exceptions.py
Expand Up @@ -82,7 +82,7 @@ class LockError(RedisError, ValueError):
# NOTE: For backwards compatibility, this class derives from ValueError.
# This was originally chosen to behave like threading.Lock.

def __init__(self, message, lock_name=None):
def __init__(self, message=None, lock_name=None):
self.message = message
self.lock_name = lock_name

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Expand Up @@ -8,7 +8,7 @@
long_description_content_type="text/markdown",
keywords=["Redis", "key-value store", "database"],
license="MIT",
version="5.1.0b3",
version="5.1.0b4",
packages=find_packages(
include=[
"redis",
Expand Down
42 changes: 42 additions & 0 deletions tests/test_asyncio/test_cache.py
Expand Up @@ -142,6 +142,48 @@ async def test_cache_return_copy(self, r):
check = cache.get(("LRANGE", "mylist", 0, -1))
assert check == [b"baz", b"bar", b"foo"]

@pytest.mark.onlynoncluster
@pytest.mark.parametrize(
    "r",
    [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}],
    indirect=True,
)
async def test_csc_not_cause_disconnects(self, r):
    """Client-side caching must not force reconnects (async client).

    Regression test: invalidation push messages left unread on the socket
    used to make the pool treat the connection as dirty and reconnect.
    CLIENT ID is captured at four points (id1..id4); all four must be
    equal, proving the same physical connection served every command.

    The fixture yields a (client, cache) pair; `r` is unpacked below.
    """
    r, cache = r
    id1 = await r.client_id()
    await r.mset({"a": 1, "b": 1, "c": 1, "d": 1, "e": 1})
    assert await r.mget("a", "b", "c", "d", "e") == ["1", "1", "1", "1", "1"]
    id2 = await r.client_id()

    # client should get value from client cache
    assert await r.mget("a", "b", "c", "d", "e") == ["1", "1", "1", "1", "1"]
    # confirm the values really are in the local cache (keyed by full command)
    assert cache.get(("MGET", "a", "b", "c", "d", "e")) == [
        "1",
        "1",
        "1",
        "1",
        "1",
    ]

    # MSET from this same client triggers server invalidation messages
    await r.mset({"a": 2, "b": 2, "c": 2, "d": 2, "e": 2})
    id3 = await r.client_id()
    # client should get value from redis server post invalidate messages
    assert await r.mget("a", "b", "c", "d", "e") == ["2", "2", "2", "2", "2"]

    await r.mset({"a": 3, "b": 3, "c": 3, "d": 3, "e": 3})
    # need to check that we get correct value 3 and not 2
    assert await r.mget("a", "b", "c", "d", "e") == ["3", "3", "3", "3", "3"]
    # client should get value from client cache
    assert await r.mget("a", "b", "c", "d", "e") == ["3", "3", "3", "3", "3"]

    await r.mset({"a": 4, "b": 4, "c": 4, "d": 4, "e": 4})
    # need to check that we get correct value 4 and not 3
    assert await r.mget("a", "b", "c", "d", "e") == ["4", "4", "4", "4", "4"]
    # client should get value from client cache
    assert await r.mget("a", "b", "c", "d", "e") == ["4", "4", "4", "4", "4"]
    id4 = await r.client_id()
    # same connection throughout => no disconnect was caused by CSC
    assert id1 == id2 == id3 == id4


@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only")
@pytest.mark.onlycluster
Expand Down
43 changes: 43 additions & 0 deletions tests/test_cache.py
Expand Up @@ -146,6 +146,49 @@ def test_cache_return_copy(self, r):
check = cache.get(("LRANGE", "mylist", 0, -1))
assert check == [b"baz", b"bar", b"foo"]

@pytest.mark.onlynoncluster
@pytest.mark.parametrize(
    "r",
    [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}],
    indirect=True,
)
def test_csc_not_cause_disconnects(self, r):
    """Client-side caching must not force reconnects (sync client).

    Regression test: invalidation push messages left unread on the socket
    used to make the pool treat the connection as dirty and reconnect.
    CLIENT ID is captured at four points (id1..id4); all four must be
    equal, proving the same physical connection served every command.

    The fixture yields a (client, cache) pair; `r` is unpacked below.
    """
    r, cache = r
    id1 = r.client_id()
    r.mset({"a": 1, "b": 1, "c": 1, "d": 1, "e": 1, "f": 1})
    assert r.mget("a", "b", "c", "d", "e", "f") == ["1", "1", "1", "1", "1", "1"]
    id2 = r.client_id()

    # client should get value from client cache
    assert r.mget("a", "b", "c", "d", "e", "f") == ["1", "1", "1", "1", "1", "1"]
    # confirm the values really are in the local cache (keyed by full command)
    assert cache.get(("MGET", "a", "b", "c", "d", "e", "f")) == [
        "1",
        "1",
        "1",
        "1",
        "1",
        "1",
    ]

    # MSET from this same client triggers server invalidation messages
    r.mset({"a": 2, "b": 2, "c": 2, "d": 2, "e": 2, "f": 2})
    id3 = r.client_id()
    # client should get value from redis server post invalidate messages
    assert r.mget("a", "b", "c", "d", "e", "f") == ["2", "2", "2", "2", "2", "2"]

    r.mset({"a": 3, "b": 3, "c": 3, "d": 3, "e": 3, "f": 3})
    # need to check that we get correct value 3 and not 2
    assert r.mget("a", "b", "c", "d", "e", "f") == ["3", "3", "3", "3", "3", "3"]
    # client should get value from client cache
    assert r.mget("a", "b", "c", "d", "e", "f") == ["3", "3", "3", "3", "3", "3"]

    r.mset({"a": 4, "b": 4, "c": 4, "d": 4, "e": 4, "f": 4})
    # need to check that we get correct value 4 and not 3
    assert r.mget("a", "b", "c", "d", "e", "f") == ["4", "4", "4", "4", "4", "4"]
    # client should get value from client cache
    assert r.mget("a", "b", "c", "d", "e", "f") == ["4", "4", "4", "4", "4", "4"]
    id4 = r.client_id()
    # same connection throughout => no disconnect was caused by CSC
    assert id1 == id2 == id3 == id4


@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only")
@pytest.mark.onlycluster
Expand Down

0 comments on commit 871630f

Please sign in to comment.