From d529c2ad8d2cf4dcfb41bfd93ea68cfefd81aa66 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Thu, 22 Feb 2024 12:48:00 +0000 Subject: [PATCH 1/5] Fix incorrect asserts in test and ensure connections are closed (#3004) --- tests/test_ssl.py | 31 +++++++++++++++++++++---------- 1 file changed, 21 insertions(+), 10 deletions(-) diff --git a/tests/test_ssl.py b/tests/test_ssl.py index 465fdabb89..dfd8837262 100644 --- a/tests/test_ssl.py +++ b/tests/test_ssl.py @@ -26,13 +26,15 @@ def test_ssl_with_invalid_cert(self, request): sslclient = redis.from_url(ssl_url) with pytest.raises(ConnectionError) as e: sslclient.ping() - assert "SSL: CERTIFICATE_VERIFY_FAILED" in str(e) + assert "SSL: CERTIFICATE_VERIFY_FAILED" in str(e) + sslclient.close() def test_ssl_connection(self, request): ssl_url = request.config.option.redis_ssl_url p = urlparse(ssl_url)[1].split(":") r = redis.Redis(host=p[0], port=p[1], ssl=True, ssl_cert_reqs="none") assert r.ping() + r.close() def test_ssl_connection_without_ssl(self, request): ssl_url = request.config.option.redis_ssl_url @@ -41,7 +43,8 @@ def test_ssl_connection_without_ssl(self, request): with pytest.raises(ConnectionError) as e: r.ping() - assert "Connection closed by server" in str(e) + assert "Connection closed by server" in str(e) + r.close() def test_validating_self_signed_certificate(self, request): ssl_url = request.config.option.redis_ssl_url @@ -56,6 +59,7 @@ def test_validating_self_signed_certificate(self, request): ssl_ca_certs=self.SERVER_CERT, ) assert r.ping() + r.close() def test_validating_self_signed_string_certificate(self, request): with open(self.SERVER_CERT) as f: @@ -72,6 +76,7 @@ def test_validating_self_signed_string_certificate(self, request): ssl_ca_data=cert_data, ) assert r.ping() + r.close() def _create_oscp_conn(self, request): ssl_url = request.config.option.redis_ssl_url @@ -92,22 +97,25 @@ def _create_oscp_conn(self, request): def test_ssl_ocsp_called(self, request): r = self._create_oscp_conn(request) with pytest.raises(RedisError) as e: - assert r.ping() - assert "cryptography not installed" in str(e) + r.ping() + assert "cryptography is not installed" in str(e) + r.close() @skip_if_nocryptography() def test_ssl_ocsp_called_withcrypto(self, request): r = self._create_oscp_conn(request) with pytest.raises(ConnectionError) as e: assert r.ping() - assert "No AIA information present in ssl certificate" in str(e) + assert "No AIA information present in ssl certificate" in str(e) + r.close() # rediss://, url based ssl_url = request.config.option.redis_ssl_url sslclient = redis.from_url(ssl_url) with pytest.raises(ConnectionError) as e: sslclient.ping() - assert "No AIA information present in ssl certificate" in str(e) + assert "No AIA information present in ssl certificate" in str(e) + sslclient.close() @skip_if_nocryptography() def test_valid_ocsp_cert_http(self): @@ -132,7 +140,7 @@ def test_revoked_ocsp_certificate(self): ocsp = OCSPVerifier(wrapped, hostname, 443) with pytest.raises(ConnectionError) as e: assert ocsp.is_valid() - assert "REVOKED" in str(e) + assert "REVOKED" in str(e) @skip_if_nocryptography() def test_unauthorized_ocsp(self): @@ -157,7 +165,7 @@ def test_ocsp_not_present_in_response(self): ocsp = OCSPVerifier(wrapped, hostname, 443) with pytest.raises(ConnectionError) as e: assert ocsp.is_valid() - assert "from the" in str(e) + assert "from the" in str(e) @skip_if_nocryptography() def test_unauthorized_then_direct(self): @@ -193,6 +201,7 @@ def test_mock_ocsp_staple(self, request): with pytest.raises(RedisError): r.ping() + r.close() ctx = OpenSSL.SSL.Context(OpenSSL.SSL.SSLv23_METHOD) ctx.use_certificate_file(self.SERVER_CERT) @@ -213,7 +222,8 @@ def test_mock_ocsp_staple(self, request): with pytest.raises(ConnectionError) as e: r.ping() - assert "no ocsp response present" in str(e) + assert "no ocsp response present" in str(e) + r.close() r = redis.Redis( host=p[0], @@ -228,4 +238,5 @@ def test_mock_ocsp_staple(self, request): with pytest.raises(ConnectionError) as e: r.ping() - assert "no ocsp response present" in str(e) + assert "no ocsp response present" in str(e) + r.close() From c573bc4ab61d0d57726f872fdfca31962d44b534 Mon Sep 17 00:00:00 2001 From: dvora-h <67596500+dvora-h@users.noreply.github.com> Date: Wed, 28 Feb 2024 14:31:59 +0200 Subject: [PATCH 2/5] Fix bug: client side caching causes unexpected disconnections (#3160) * fix disconnects * skip test in cluster --------- Co-authored-by: Chayim --- redis/_parsers/resp3.py | 4 +++- redis/client.py | 14 +++++++------- redis/commands/core.py | 2 +- redis/connection.py | 17 +++++++--------- tests/test_cache.py | 43 +++++++++++++++++++++++++++++++++++++++++ 5 files changed, 61 insertions(+), 19 deletions(-) diff --git a/redis/_parsers/resp3.py b/redis/_parsers/resp3.py index 13aa1ffccb..88c8d5e52b 100644 --- a/redis/_parsers/resp3.py +++ b/redis/_parsers/resp3.py @@ -117,7 +117,9 @@ def _read_response(self, disable_decoding=False, push_request=False): ) for _ in range(int(response)) ] - self.handle_push_response(response, disable_decoding, push_request) + response = self.handle_push_response( + response, disable_decoding, push_request + ) else: raise InvalidResponse(f"Protocol Error: {raw!r}") diff --git a/redis/client.py b/redis/client.py index 85ed7380a8..79f52cc989 100755 --- a/redis/client.py +++ b/redis/client.py @@ -563,10 +563,10 @@ def execute_command(self, *args, **options): pool = self.connection_pool conn = self.connection or pool.get_connection(command_name, **options) response_from_cache = conn._get_from_local_cache(args) - if response_from_cache is not None: - return response_from_cache - else: - try: + try: + if response_from_cache is not None: + return response_from_cache + else: response = conn.retry.call_with_retry( lambda: self._send_command_parse_response( conn, command_name, *args, **options @@ -575,9 +575,9 @@ def execute_command(self, *args, **options): ) conn._add_to_local_cache(args, response, keys) return response - finally: - if not self.connection: - pool.release(conn) + finally: + if not self.connection: + pool.release(conn) def parse_response(self, connection, command_name, **options): """Parses a response from the Redis server""" diff --git a/redis/commands/core.py b/redis/commands/core.py index 6d81d76035..464e8d8c85 100644 --- a/redis/commands/core.py +++ b/redis/commands/core.py @@ -2011,7 +2011,7 @@ def mget(self, keys: KeysT, *args: EncodableT) -> ResponseT: options = {} if not args: options[EMPTY_RESPONSE] = [] - options["keys"] = keys + options["keys"] = args return self.execute_command("MGET", *args, **options) def mset(self, mapping: Mapping[AnyKeyT, EncodableT]) -> ResponseT: diff --git a/redis/connection.py b/redis/connection.py index 617d04af5c..b89ce0e94b 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -1,6 +1,5 @@ import copy import os -import select import socket import ssl import sys @@ -609,11 +608,6 @@ def pack_commands(self, commands): output.append(SYM_EMPTY.join(pieces)) return output - def _socket_is_empty(self): - """Check if the socket is empty""" - r, _, _ = select.select([self._sock], [], [], 0) - return not bool(r) - def _cache_invalidation_process( self, data: List[Union[str, Optional[List[str]]]] ) -> None: @@ -639,7 +633,7 @@ def _get_from_local_cache(self, command: str): or command[0] not in self.cache_whitelist ): return None - while not self._socket_is_empty(): + while self.can_read(): self.read_response(push_request=True) return self.client_cache.get(command) @@ -1187,12 +1181,15 @@ def get_connection(self, command_name: str, *keys, **options) -> "Connection": try: # ensure this connection is connected to Redis connection.connect() - # connections that the pool provides should be ready to send - # a command. if not, the connection was either returned to the + # if client caching is not enabled connections that the pool + # provides should be ready to send a command. + # if not, the connection was either returned to the # pool before all data has been read or the socket has been # closed. either way, reconnect and verify everything is good. + # (if caching enabled the connection will not always be ready + # to send a command because it may contain invalidation messages) try: - if connection.can_read(): + if connection.can_read() and connection.client_cache is None: raise ConnectionError("Connection has data") except (ConnectionError, OSError): connection.disconnect() diff --git a/tests/test_cache.py b/tests/test_cache.py index 4eb5160ecc..dd33afd23e 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -146,6 +146,49 @@ def test_cache_return_copy(self, r): check = cache.get(("LRANGE", "mylist", 0, -1)) assert check == [b"baz", b"bar", b"foo"] + @pytest.mark.onlynoncluster + @pytest.mark.parametrize( + "r", + [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], + indirect=True, + ) + def test_csc_not_cause_disconnects(self, r): + r, cache = r + id1 = r.client_id() + r.mset({"a": 1, "b": 1, "c": 1, "d": 1, "e": 1, "f": 1}) + assert r.mget("a", "b", "c", "d", "e", "f") == ["1", "1", "1", "1", "1", "1"] + id2 = r.client_id() + + # client should get value from client cache + assert r.mget("a", "b", "c", "d", "e", "f") == ["1", "1", "1", "1", "1", "1"] + assert cache.get(("MGET", "a", "b", "c", "d", "e", "f")) == [ + "1", + "1", + "1", + "1", + "1", + "1", + ] + + r.mset({"a": 2, "b": 2, "c": 2, "d": 2, "e": 2, "f": 2}) + id3 = r.client_id() + # client should get value from redis server post invalidate messages + assert r.mget("a", "b", "c", "d", "e", "f") == ["2", "2", "2", "2", "2", "2"] + + r.mset({"a": 3, "b": 3, "c": 3, "d": 3, "e": 3, "f": 3}) + # need to check that we get correct value 3 and not 2 + assert r.mget("a", "b", "c", "d", "e", "f") == ["3", "3", "3", "3", "3", "3"] + # client should get value from client cache + assert r.mget("a", "b", "c", "d", "e", "f") == ["3", "3", "3", "3", "3", "3"] + + r.mset({"a": 4, "b": 4, "c": 4, "d": 4, "e": 4, "f": 4}) + # need to check that we get correct value 4 and not 3 + assert r.mget("a", "b", "c", "d", "e", "f") == ["4", "4", "4", "4", "4", "4"] + # client should get value from client cache + assert r.mget("a", "b", "c", "d", "e", "f") == ["4", "4", "4", "4", "4", "4"] + id4 = r.client_id() + assert id1 == id2 == id3 == id4 + @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") @pytest.mark.onlycluster From 26ab964ec18ec255672abaec90de439705151b5c Mon Sep 17 00:00:00 2001 From: dvora-h <67596500+dvora-h@users.noreply.github.com> Date: Thu, 29 Feb 2024 11:59:31 +0200 Subject: [PATCH 3/5] Fix bug: client side caching causes unexpected disconnections (async version) (#3165) * fix disconnects * skip test in cluster * add test * save return value from handle_push_response (without it 'read_response' return the push message) * insert return response from cache to the try block to prevent connection leak * enable to get connection with data avaliable to read in csc mode and change can_read_destructive to not read data * fix check if socket is empty (at_eof() can return False but this doesn't mean there's definitely more data to read) --------- Co-authored-by: Chayim --- redis/_parsers/base.py | 2 +- redis/_parsers/resp3.py | 4 ++- redis/asyncio/client.py | 40 +++++++++++++++--------------- redis/asyncio/connection.py | 14 ++++++++--- tests/test_asyncio/test_cache.py | 42 ++++++++++++++++++++++++++++++++ 5 files changed, 77 insertions(+), 25 deletions(-) diff --git a/redis/_parsers/base.py b/redis/_parsers/base.py index 8e59249bef..0137539d66 100644 --- a/redis/_parsers/base.py +++ b/redis/_parsers/base.py @@ -182,7 +182,7 @@ async def can_read_destructive(self) -> bool: return True try: async with async_timeout(0): - return await self._stream.read(1) + return self._stream.at_eof() except TimeoutError: return False diff --git a/redis/_parsers/resp3.py b/redis/_parsers/resp3.py index 88c8d5e52b..7afa43a0c2 100644 --- a/redis/_parsers/resp3.py +++ b/redis/_parsers/resp3.py @@ -261,7 +261,9 @@ async def _read_response( ) for _ in range(int(response)) ] - await self.handle_push_response(response, disable_decoding, push_request) + response = await self.handle_push_response( + response, disable_decoding, push_request + ) else: raise InvalidResponse(f"Protocol Error: {raw!r}") diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index 3e2912bfca..9ff2e3917f 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -629,25 +629,27 @@ async def execute_command(self, *args, **options): pool = self.connection_pool conn = self.connection or await pool.get_connection(command_name, **options) response_from_cache = await conn._get_from_local_cache(args) - if response_from_cache is not None: - return response_from_cache - else: - if self.single_connection_client: - await self._single_conn_lock.acquire() - try: - response = await conn.retry.call_with_retry( - lambda: self._send_command_parse_response( - conn, command_name, *args, **options - ), - lambda error: self._disconnect_raise(conn, error), - ) - conn._add_to_local_cache(args, response, keys) - return response - finally: - if self.single_connection_client: - self._single_conn_lock.release() - if not self.connection: - await pool.release(conn) + try: + if response_from_cache is not None: + return response_from_cache + else: + try: + if self.single_connection_client: + await self._single_conn_lock.acquire() + response = await conn.retry.call_with_retry( + lambda: self._send_command_parse_response( + conn, command_name, *args, **options + ), + lambda error: self._disconnect_raise(conn, error), + ) + conn._add_to_local_cache(args, response, keys) + return response + finally: + if self.single_connection_client: + self._single_conn_lock.release() + finally: + if not self.connection: + await pool.release(conn) async def parse_response( self, connection: Connection, command_name: Union[str, bytes], **options diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index 81df3b3543..6c5c58c683 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -685,7 +685,7 @@ def pack_commands(self, commands: Iterable[Iterable[EncodableT]]) -> List[bytes] def _socket_is_empty(self): """Check if the socket is empty""" - return not self._reader.at_eof() + return len(self._reader._buffer) == 0 def _cache_invalidation_process( self, data: List[Union[str, Optional[List[str]]]] @@ -1192,12 +1192,18 @@ def make_connection(self): async def ensure_connection(self, connection: AbstractConnection): """Ensure that the connection object is connected and valid""" await connection.connect() - # connections that the pool provides should be ready to send - # a command. if not, the connection was either returned to the + # if client caching is not enabled connections that the pool + # provides should be ready to send a command. + # if not, the connection was either returned to the # pool before all data has been read or the socket has been # closed. either way, reconnect and verify everything is good. + # (if caching enabled the connection will not always be ready + # to send a command because it may contain invalidation messages) try: - if await connection.can_read_destructive(): + if ( + await connection.can_read_destructive() + and connection.client_cache is None + ): raise ConnectionError("Connection has data") from None except (ConnectionError, OSError): await connection.disconnect() diff --git a/tests/test_asyncio/test_cache.py b/tests/test_asyncio/test_cache.py index bf20337dfb..4762bb7c05 100644 --- a/tests/test_asyncio/test_cache.py +++ b/tests/test_asyncio/test_cache.py @@ -142,6 +142,48 @@ async def test_cache_return_copy(self, r): check = cache.get(("LRANGE", "mylist", 0, -1)) assert check == [b"baz", b"bar", b"foo"] + @pytest.mark.onlynoncluster + @pytest.mark.parametrize( + "r", + [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], + indirect=True, + ) + async def test_csc_not_cause_disconnects(self, r): + r, cache = r + id1 = await r.client_id() + await r.mset({"a": 1, "b": 1, "c": 1, "d": 1, "e": 1}) + assert await r.mget("a", "b", "c", "d", "e") == ["1", "1", "1", "1", "1"] + id2 = await r.client_id() + + # client should get value from client cache + assert await r.mget("a", "b", "c", "d", "e") == ["1", "1", "1", "1", "1"] + assert cache.get(("MGET", "a", "b", "c", "d", "e")) == [ + "1", + "1", + "1", + "1", + "1", + ] + + await r.mset({"a": 2, "b": 2, "c": 2, "d": 2, "e": 2}) + id3 = await r.client_id() + # client should get value from redis server post invalidate messages + assert await r.mget("a", "b", "c", "d", "e") == ["2", "2", "2", "2", "2"] + + await r.mset({"a": 3, "b": 3, "c": 3, "d": 3, "e": 3}) + # need to check that we get correct value 3 and not 2 + assert await r.mget("a", "b", "c", "d", "e") == ["3", "3", "3", "3", "3"] + # client should get value from client cache + assert await r.mget("a", "b", "c", "d", "e") == ["3", "3", "3", "3", "3"] + + await r.mset({"a": 4, "b": 4, "c": 4, "d": 4, "e": 4}) + # need to check that we get correct value 4 and not 3 + assert await r.mget("a", "b", "c", "d", "e") == ["4", "4", "4", "4", "4"] + # client should get value from client cache + assert await r.mget("a", "b", "c", "d", "e") == ["4", "4", "4", "4", "4"] + id4 = await r.client_id() + assert id1 == id2 == id3 == id4 + @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") @pytest.mark.onlycluster From 9df2225ba6309d50742959328958755210d757bd Mon Sep 17 00:00:00 2001 From: dvora-h <67596500+dvora-h@users.noreply.github.com> Date: Thu, 29 Feb 2024 13:07:37 +0200 Subject: [PATCH 4/5] Version 5.1.0b4 (#3166) --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index c6a9e205f5..68bfc25c42 100644 --- a/setup.py +++ b/setup.py @@ -8,7 +8,7 @@ long_description_content_type="text/markdown", keywords=["Redis", "key-value store", "database"], license="MIT", - version="5.1.0b3", + version="5.1.0b4", packages=find_packages( include=[ "redis", From 9ad1546c06bb7321e7e19bbb9c8dd758343c0390 Mon Sep 17 00:00:00 2001 From: dvora-h <67596500+dvora-h@users.noreply.github.com> Date: Thu, 7 Mar 2024 15:12:48 +0200 Subject: [PATCH 5/5] Fix lock error (#3176) --- redis/exceptions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/redis/exceptions.py b/redis/exceptions.py index ddb4041da3..8af58cb0db 100644 --- a/redis/exceptions.py +++ b/redis/exceptions.py @@ -82,7 +82,7 @@ class LockError(RedisError, ValueError): # NOTE: For backwards compatibility, this class derives from ValueError. # This was originally chosen to behave like threading.Lock. - def __init__(self, message, lock_name=None): + def __init__(self, message=None, lock_name=None): self.message = message self.lock_name = lock_name