From a55cdb1f9a1710ca63c2eddee90cf7b2ef2ba797 Mon Sep 17 00:00:00 2001 From: Gabriel Erzse Date: Thu, 25 Apr 2024 15:20:55 +0300 Subject: [PATCH] Streamline client side caching API typing Streamline the typing of the client side caching API. Some of the methods are defining commands of type `str`, while in reality tuples are being sent for those parameters. Add client side cache tests for Sentinels. In order to make this work, fix the sentinel configuration in the docker-compose stack. Add a test for client side caching with a truly custom cache, not just injecting our internal cache structure as custom. Add a test for client side caching where two different types of commands use the same key, to make sure they invalidate each others cached data. --- dev_requirements.txt | 1 + dockers/sentinel.conf | 3 +- redis/_cache.py | 54 ++++++++----- tests/conftest.py | 42 ++++++++++ tests/test_asyncio/conftest.py | 29 +++++++ tests/test_asyncio/test_cache.py | 45 +++++++++++ tests/test_cache.py | 130 ++++++++++++++++++++++++++++++- 7 files changed, 282 insertions(+), 22 deletions(-) diff --git a/dev_requirements.txt b/dev_requirements.txt index ef3b1aa22d..48ec278d83 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -1,5 +1,6 @@ click==8.0.4 black==24.3.0 +cachetools flake8==5.0.4 flake8-isort==6.0.0 flynt~=0.69.0 diff --git a/dockers/sentinel.conf b/dockers/sentinel.conf index 1a33f53344..75f711e5d4 100644 --- a/dockers/sentinel.conf +++ b/dockers/sentinel.conf @@ -1,4 +1,5 @@ -sentinel monitor redis-py-test 127.0.0.1 6379 2 +sentinel resolve-hostnames yes +sentinel monitor redis-py-test redis 6379 2 sentinel down-after-milliseconds redis-py-test 5000 sentinel failover-timeout redis-py-test 60000 sentinel parallel-syncs redis-py-test 1 \ No newline at end of file diff --git a/redis/_cache.py b/redis/_cache.py index 7acfdde3e7..20e0823d17 100644 --- a/redis/_cache.py +++ b/redis/_cache.py @@ -4,13 +4,12 @@ from abc import ABC, abstractmethod from collections import OrderedDict, defaultdict from enum import Enum -from typing import List +from typing import Iterable, List, Union from redis.typing import KeyT, ResponseT DEFAULT_EVICTION_POLICY = "lru" - DEFAULT_BLACKLIST = [ "BF.CARD", "BF.DEBUG", @@ -71,7 +70,6 @@ "TTL", ] - DEFAULT_WHITELIST = [ "BITCOUNT", "BITFIELD_RO", @@ -215,7 +213,6 @@ def __init__( max_size: int = 10000, ttl: int = 0, eviction_policy: EvictionPolicy = DEFAULT_EVICTION_POLICY, - **kwargs, ): self.max_size = max_size self.ttl = ttl @@ -224,12 +221,17 @@ def __init__( self.key_commands_map = defaultdict(set) self.commands_ttl_list = [] - def set(self, command: str, response: ResponseT, keys_in_command: List[KeyT]): + def set( + self, + command: Union[str, Iterable[str]], + response: ResponseT, + keys_in_command: List[KeyT], + ): """ Set a redis command and its response in the cache. Args: - command (str): The redis command. + command (Union[str, Iterable[str]]): The redis command. response (ResponseT): The response associated with the command. keys_in_command (List[KeyT]): The list of keys used in the command. """ @@ -244,12 +246,12 @@ def set(self, command: str, response: ResponseT, keys_in_command: List[KeyT]): self._update_key_commands_map(keys_in_command, command) self.commands_ttl_list.append(command) - def get(self, command: str) -> ResponseT: + def get(self, command: Union[str, Iterable[str]]) -> ResponseT: """ Get the response for a redis command from the cache. Args: - command (str): The redis command. + command (Union[str, Iterable[str]]): The redis command. Returns: ResponseT: The response associated with the command, or None if the command is not in the cache. # noqa @@ -261,12 +263,12 @@ def get(self, command: str) -> ResponseT: self._update_access(command) return copy.deepcopy(self.cache[command]["response"]) - def delete_command(self, command: str): + def delete_command(self, command: Union[str, Iterable[str]]): """ Delete a redis command and its metadata from the cache. Args: - command (str): The redis command to be deleted. + command (Union[str, Iterable[str]]): The redis command to be deleted. """ if command in self.cache: keys_in_command = self.cache[command].get("keys") @@ -274,8 +276,16 @@ def delete_command(self, command: str): self.commands_ttl_list.remove(command) del self.cache[command] - def delete_many(self, commands): - pass + def delete_many(self, commands: List[Union[str, Iterable[str]]]): + """ + Delete multiple commands and their metadata from the cache. + + Args: + commands (List[Union[str, Iterable[str]]]): The list of commands to be + deleted. + """ + for command in commands: + self.delete_command(command) def flush(self): """Clear the entire cache, removing all redis commands and metadata.""" @@ -283,12 +293,12 @@ def flush(self): self.key_commands_map.clear() self.commands_ttl_list = [] - def _is_expired(self, command: str) -> bool: + def _is_expired(self, command: Union[str, Iterable[str]]) -> bool: """ Check if a redis command has expired based on its time-to-live. Args: - command (str): The redis command. + command (Union[str, Iterable[str]]): The redis command. Returns: bool: True if the command has expired, False otherwise. @@ -297,12 +307,12 @@ def _is_expired(self, command: str) -> bool: return False return time.monotonic() - self.cache[command]["ctime"] > self.ttl - def _update_access(self, command: str): + def _update_access(self, command: Union[str, Iterable[str]]): """ Update the access information for a redis command based on the eviction policy. Args: - command (str): The redis command. + command (Union[str, Iterable[str]]): The redis command. """ if self.eviction_policy == EvictionPolicy.LRU.value: self.cache.move_to_end(command) @@ -329,24 +339,28 @@ def _evict(self): random_command = random.choice(list(self.cache.keys())) self.cache.pop(random_command) - def _update_key_commands_map(self, keys: List[KeyT], command: str): + def _update_key_commands_map( + self, keys: List[KeyT], command: Union[str, Iterable[str]] + ): """ Update the key_commands_map with command that uses the keys. Args: keys (List[KeyT]): The list of keys used in the command. - command (str): The redis command. + command (Union[str, Iterable[str]]): The redis command. """ for key in keys: self.key_commands_map[key].add(command) - def _del_key_commands_map(self, keys: List[KeyT], command: str): + def _del_key_commands_map( + self, keys: List[KeyT], command: Union[str, Iterable[str]] + ): """ Remove a redis command from the key_commands_map. Args: keys (List[KeyT]): The list of keys used in the redis command. - command (str): The redis command. + command (Union[str, Iterable[str]]): The redis command. """ for key in keys: self.key_commands_map[key].remove(command) diff --git a/tests/conftest.py b/tests/conftest.py index 8786e2b9f0..e783b6e8f9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,6 +9,7 @@ import pytest import redis from packaging.version import Version +from redis import Sentinel from redis.backoff import NoBackoff from redis.connection import Connection, parse_url from redis.exceptions import RedisClusterException @@ -105,6 +106,19 @@ def pytest_addoption(parser): "--uvloop", action=BooleanOptionalAction, help="Run tests with uvloop" ) + parser.addoption( + "--sentinels", + action="store", + default="localhost:26379,localhost:26380,localhost:26381", + help="Comma-separated list of sentinel IPs and ports", + ) + parser.addoption( + "--master-service", + action="store", + default="redis-py-test", + help="Name of the Redis master service that the sentinels are monitoring", + ) + def _get_info(redis_url): client = redis.Redis.from_url(redis_url) @@ -352,6 +366,34 @@ def sslclient(request): yield client +@pytest.fixture() +def sentinel_setup(local_cache, request): + sentinel_ips = request.config.getoption("--sentinels") + sentinel_endpoints = [ + (ip.strip(), int(port.strip())) + for ip, port in (endpoint.split(":") for endpoint in sentinel_ips.split(",")) + ] + kwargs = request.param.get("kwargs", {}) if hasattr(request, "param") else {} + sentinel = Sentinel( + sentinel_endpoints, + socket_timeout=0.1, + client_cache=local_cache, + protocol=3, + **kwargs, + ) + yield sentinel + for s in sentinel.sentinels: + s.close() + + +@pytest.fixture() +def master(request, sentinel_setup): + master_service = request.config.getoption("--master-service") + master = sentinel_setup.master_for(master_service) + yield master + master.close() + + def _gen_cluster_mock_resp(r, response): connection = Mock(spec=Connection) connection.retry = Retry(NoBackoff(), 0) diff --git a/tests/test_asyncio/conftest.py b/tests/test_asyncio/conftest.py index c6afec5af6..1216b2edf9 100644 --- a/tests/test_asyncio/conftest.py +++ b/tests/test_asyncio/conftest.py @@ -7,6 +7,7 @@ import redis.asyncio as redis from packaging.version import Version from redis._parsers import _AsyncHiredisParser, _AsyncRESP2Parser +from redis.asyncio import Sentinel from redis.asyncio.client import Monitor from redis.asyncio.connection import Connection, parse_url from redis.asyncio.retry import Retry @@ -136,6 +137,34 @@ async def decoded_r(create_redis): return await create_redis(decode_responses=True) +@pytest_asyncio.fixture() +async def sentinel_setup(local_cache, request): + sentinel_ips = request.config.getoption("--sentinels") + sentinel_endpoints = [ + (ip.strip(), int(port.strip())) + for ip, port in (endpoint.split(":") for endpoint in sentinel_ips.split(",")) + ] + kwargs = request.param.get("kwargs", {}) if hasattr(request, "param") else {} + sentinel = Sentinel( + sentinel_endpoints, + socket_timeout=0.1, + client_cache=local_cache, + protocol=3, + **kwargs, + ) + yield sentinel + for s in sentinel.sentinels: + await s.close() + + +@pytest_asyncio.fixture() +async def master(request, sentinel_setup): + master_service = request.config.getoption("--master-service") + master = sentinel_setup.master_for(master_service) + yield master + await master.close() + + def _gen_cluster_mock_resp(r, response): connection = mock.AsyncMock(spec=Connection) connection.retry = Retry(NoBackoff(), 0) diff --git a/tests/test_asyncio/test_cache.py b/tests/test_asyncio/test_cache.py index 4762bb7c05..08227fb61e 100644 --- a/tests/test_asyncio/test_cache.py +++ b/tests/test_asyncio/test_cache.py @@ -14,6 +14,11 @@ async def r(request, create_redis): yield r, cache +@pytest_asyncio.fixture() +async def local_cache(): + yield _LocalCache() + + @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") class TestLocalCache: @pytest.mark.onlynoncluster @@ -228,3 +233,43 @@ async def test_cache_decode_response(self, r): assert cache.get(("GET", "foo")) is None # get key from redis assert await r.get("foo") == "barbar" + + +@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") +@pytest.mark.onlynoncluster +class TestSentinelLocalCache: + + async def test_get_from_cache(self, local_cache, master): + await master.set("foo", "bar") + # get key from redis and save in local cache + assert await master.get("foo") == b"bar" + # get key from local cache + assert local_cache.get(("GET", "foo")) == b"bar" + # change key in redis (cause invalidation) + await master.set("foo", "barbar") + # send any command to redis (process invalidation in background) + await master.ping() + # the command is not in the local cache anymore + assert local_cache.get(("GET", "foo")) is None + # get key from redis + assert await master.get("foo") == b"barbar" + + @pytest.mark.parametrize( + "sentinel_setup", + [{"kwargs": {"decode_responses": True}}], + indirect=True, + ) + async def test_cache_decode_response(self, local_cache, sentinel_setup, master): + await master.set("foo", "bar") + # get key from redis and save in local cache + assert await master.get("foo") == "bar" + # get key from local cache + assert local_cache.get(("GET", "foo")) == "bar" + # change key in redis (cause invalidation) + await master.set("foo", "barbar") + # send any command to redis (process invalidation in background) + await master.ping() + # the command is not in the local cache anymore + assert local_cache.get(("GET", "foo")) is None + # get key from redis + assert await master.get("foo") == "barbar" diff --git a/tests/test_cache.py b/tests/test_cache.py index dd33afd23e..e3e78e293c 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -1,8 +1,12 @@ import time +from collections import defaultdict +from typing import List +import cachetools import pytest import redis -from redis._cache import _LocalCache +from redis._cache import AbstractCache, _LocalCache +from redis.typing import KeyT, ResponseT from redis.utils import HIREDIS_AVAILABLE from tests.conftest import _get_client @@ -18,6 +22,11 @@ def r(request): # client.flushdb() +@pytest.fixture() +def local_cache(): + return _LocalCache() + + @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") class TestLocalCache: @pytest.mark.onlynoncluster @@ -189,6 +198,27 @@ def test_csc_not_cause_disconnects(self, r): id4 = r.client_id() assert id1 == id2 == id3 == id4 + @pytest.mark.onlynoncluster + @pytest.mark.parametrize( + "r", + [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], + indirect=True, + ) + def test_multiple_commands_same_key(self, r): + r, cache = r + r.mset({"a": 1, "b": 1}) + assert r.mget("a", "b") == ["1", "1"] + # value should be in local cache + assert cache.get(("MGET", "a", "b")) == ["1", "1"] + # set only one key + r.set("a", 2) + # send any command to redis (process invalidation in background) + r.ping() + # the command is not in the local cache anymore + assert cache.get(("MGET", "a", "b")) is None + # get from redis + assert r.mget("a", "b") == ["2", "1"] + @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") @pytest.mark.onlycluster @@ -233,3 +263,101 @@ def test_cache_decode_response(self, r): assert cache.get(("GET", "foo")) is None # get key from redis assert r.get("foo") == "barbar" + + +@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") +@pytest.mark.onlynoncluster +class TestSentinelLocalCache: + + def test_get_from_cache(self, local_cache, master): + master.set("foo", "bar") + # get key from redis and save in local cache + assert master.get("foo") == b"bar" + # get key from local cache + assert local_cache.get(("GET", "foo")) == b"bar" + # change key in redis (cause invalidation) + master.set("foo", "barbar") + # send any command to redis (process invalidation in background) + master.ping() + # the command is not in the local cache anymore + assert local_cache.get(("GET", "foo")) is None + # get key from redis + assert master.get("foo") == b"barbar" + + @pytest.mark.parametrize( + "sentinel_setup", + [{"kwargs": {"decode_responses": True}}], + indirect=True, + ) + def test_cache_decode_response(self, local_cache, sentinel_setup, master): + master.set("foo", "bar") + # get key from redis and save in local cache + assert master.get("foo") == "bar" + # get key from local cache + assert local_cache.get(("GET", "foo")) == "bar" + # change key in redis (cause invalidation) + master.set("foo", "barbar") + # send any command to redis (process invalidation in background) + master.ping() + # the command is not in the local cache anymore + assert local_cache.get(("GET", "foo")) is None + # get key from redis + assert master.get("foo") == "barbar" + + +@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") +@pytest.mark.onlynoncluster +class TestCustomCache: + class _CustomCache(AbstractCache): + def __init__(self): + self.responses = cachetools.LRUCache(maxsize=1000) + self.keys_to_commands = defaultdict(list) + self.commands_to_keys = defaultdict(list) + + def set(self, command: str, response: ResponseT, keys_in_command: List[KeyT]): + self.responses[command] = response + for key in keys_in_command: + self.keys_to_commands[key].append(tuple(command)) + self.commands_to_keys[command].append(tuple(keys_in_command)) + + def get(self, command: str) -> ResponseT: + return self.responses.get(command) + + def delete_command(self, command: str): + self.responses.pop(command, None) + keys = self.commands_to_keys.pop(command, []) + for key in keys: + if command in self.keys_to_commands[key]: + self.keys_to_commands[key].remove(command) + + def delete_many(self, commands): + for command in commands: + self.delete_command(command) + + def flush(self): + self.responses.clear() + self.commands_to_keys.clear() + self.keys_to_commands.clear() + + def invalidate_key(self, key: KeyT): + commands = self.keys_to_commands.pop(key, []) + for command in commands: + self.delete_command(command) + + @pytest.mark.parametrize("r", [{"cache": _CustomCache()}], indirect=True) + def test_get_from_cache(self, r, r2): + r, cache = r + # add key to redis + r.set("foo", "bar") + # get key from redis and save in local cache + assert r.get("foo") == b"bar" + # get key from local cache + assert cache.get(("GET", "foo")) == b"bar" + # change key in redis (cause invalidation) + r2.set("foo", "barbar") + # send any command to redis (process invalidation in background) + r.ping() + # the command is not in the local cache anymore + assert cache.get(("GET", "foo")) is None + # get key from redis + assert r.get("foo") == b"barbar"