Streamline client side caching API typing
Streamline the typing of the client-side caching API. Some of the
methods declare their command parameters as `str`, while in reality
tuples are passed for those parameters.

Add client-side cache tests for Sentinel connections. To make this
work, fix the sentinel configuration in the docker-compose stack.

Add a test for client-side caching with a truly custom cache, rather
than just injecting our internal cache structure as the custom
implementation.

Add a test for client-side caching where two different types of
commands use the same key, to make sure they invalidate each other's
cached data.
Gabriel Erzse committed Apr 25, 2024
1 parent 07fc339 commit a55cdb1
Showing 7 changed files with 282 additions and 22 deletions.
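
As an aside (not part of this commit), a minimal sketch of the tuple-keyed usage that the widened annotations below describe. The commands and values here are hypothetical, but the method names and signatures match the `_LocalCache` changes in redis/_cache.py:

from redis._cache import _LocalCache

cache = _LocalCache(max_size=100, ttl=0)

# Commands are passed as tuples rather than plain strings, which is why the
# annotations widen `command: str` to `Union[str, Iterable[str]]`.
cache.set(("GET", "foo"), b"bar", ["foo"])
assert cache.get(("GET", "foo")) == b"bar"

# Deleting a cached entry uses the same tuple form.
cache.delete_command(("GET", "foo"))
assert cache.get(("GET", "foo")) is None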
1 change: 1 addition & 0 deletions dev_requirements.txt
@@ -1,5 +1,6 @@
click==8.0.4
black==24.3.0
cachetools
flake8==5.0.4
flake8-isort==6.0.0
flynt~=0.69.0
3 changes: 2 additions & 1 deletion dockers/sentinel.conf
@@ -1,4 +1,5 @@
sentinel monitor redis-py-test 127.0.0.1 6379 2
sentinel resolve-hostnames yes
sentinel monitor redis-py-test redis 6379 2
sentinel down-after-milliseconds redis-py-test 5000
sentinel failover-timeout redis-py-test 60000
sentinel parallel-syncs redis-py-test 1
54 changes: 34 additions & 20 deletions redis/_cache.py
@@ -4,13 +4,12 @@
from abc import ABC, abstractmethod
from collections import OrderedDict, defaultdict
from enum import Enum
from typing import List
from typing import Iterable, List, Union

from redis.typing import KeyT, ResponseT

DEFAULT_EVICTION_POLICY = "lru"


DEFAULT_BLACKLIST = [
"BF.CARD",
"BF.DEBUG",
@@ -71,7 +70,6 @@
"TTL",
]


DEFAULT_WHITELIST = [
"BITCOUNT",
"BITFIELD_RO",
@@ -215,7 +213,6 @@ def __init__(
max_size: int = 10000,
ttl: int = 0,
eviction_policy: EvictionPolicy = DEFAULT_EVICTION_POLICY,
**kwargs,
):
self.max_size = max_size
self.ttl = ttl
@@ -224,12 +221,17 @@ def __init__(
self.key_commands_map = defaultdict(set)
self.commands_ttl_list = []

def set(self, command: str, response: ResponseT, keys_in_command: List[KeyT]):
def set(
self,
command: Union[str, Iterable[str]],
response: ResponseT,
keys_in_command: List[KeyT],
):
"""
Set a redis command and its response in the cache.
Args:
command (str): The redis command.
command (Union[str, Iterable[str]]): The redis command.
response (ResponseT): The response associated with the command.
keys_in_command (List[KeyT]): The list of keys used in the command.
"""
@@ -244,12 +246,12 @@ def set(self, command: str, response: ResponseT, keys_in_command: List[KeyT]):
self._update_key_commands_map(keys_in_command, command)
self.commands_ttl_list.append(command)

def get(self, command: str) -> ResponseT:
def get(self, command: Union[str, Iterable[str]]) -> ResponseT:
"""
Get the response for a redis command from the cache.
Args:
command (str): The redis command.
command (Union[str, Iterable[str]]): The redis command.
Returns:
ResponseT: The response associated with the command, or None if the command is not in the cache. # noqa
@@ -261,34 +263,42 @@ def get(self, command: str) -> ResponseT:
self._update_access(command)
return copy.deepcopy(self.cache[command]["response"])

def delete_command(self, command: str):
def delete_command(self, command: Union[str, Iterable[str]]):
"""
Delete a redis command and its metadata from the cache.
Args:
command (str): The redis command to be deleted.
command (Union[str, Iterable[str]]): The redis command to be deleted.
"""
if command in self.cache:
keys_in_command = self.cache[command].get("keys")
self._del_key_commands_map(keys_in_command, command)
self.commands_ttl_list.remove(command)
del self.cache[command]

def delete_many(self, commands):
pass
def delete_many(self, commands: List[Union[str, Iterable[str]]]):
"""
Delete multiple commands and their metadata from the cache.
Args:
commands (List[Union[str, Iterable[str]]]): The list of commands to be
deleted.
"""
for command in commands:
self.delete_command(command)

def flush(self):
"""Clear the entire cache, removing all redis commands and metadata."""
self.cache.clear()
self.key_commands_map.clear()
self.commands_ttl_list = []

def _is_expired(self, command: str) -> bool:
def _is_expired(self, command: Union[str, Iterable[str]]) -> bool:
"""
Check if a redis command has expired based on its time-to-live.
Args:
command (str): The redis command.
command (Union[str, Iterable[str]]): The redis command.
Returns:
bool: True if the command has expired, False otherwise.
@@ -297,12 +307,12 @@ def _is_expired(self, command: str) -> bool:
return False
return time.monotonic() - self.cache[command]["ctime"] > self.ttl

def _update_access(self, command: str):
def _update_access(self, command: Union[str, Iterable[str]]):
"""
Update the access information for a redis command based on the eviction policy.
Args:
command (str): The redis command.
command (Union[str, Iterable[str]]): The redis command.
"""
if self.eviction_policy == EvictionPolicy.LRU.value:
self.cache.move_to_end(command)
@@ -329,24 +339,28 @@ def _evict(self):
random_command = random.choice(list(self.cache.keys()))
self.cache.pop(random_command)

def _update_key_commands_map(self, keys: List[KeyT], command: str):
def _update_key_commands_map(
self, keys: List[KeyT], command: Union[str, Iterable[str]]
):
"""
Update the key_commands_map with command that uses the keys.
Args:
keys (List[KeyT]): The list of keys used in the command.
command (str): The redis command.
command (Union[str, Iterable[str]]): The redis command.
"""
for key in keys:
self.key_commands_map[key].add(command)

def _del_key_commands_map(self, keys: List[KeyT], command: str):
def _del_key_commands_map(
self, keys: List[KeyT], command: Union[str, Iterable[str]]
):
"""
Remove a redis command from the key_commands_map.
Args:
keys (List[KeyT]): The list of keys used in the redis command.
command (str): The redis command.
command (Union[str, Iterable[str]]): The redis command.
"""
for key in keys:
self.key_commands_map[key].remove(command)
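
As background for the cross-command invalidation test mentioned in the commit message, here is a rough sketch of how the key-to-commands map and the newly implemented delete_many() can drop every cached entry that touched an invalidated key. The invalidation hook that calls into the cache is not part of this diff, and key_commands_map is an internal attribute, so treat this as an illustration only:

from redis._cache import _LocalCache

cache = _LocalCache()

# Two different command types cache responses for the same key "foo".
cache.set(("GET", "foo"), b"bar", ["foo"])
cache.set(("GETRANGE", "foo", "0", "-1"), b"bar", ["foo"])

# When "foo" is invalidated, every command that used it can be removed with a
# single delete_many() call; copy the set first, since delete_command mutates it.
stale_commands = list(cache.key_commands_map.get("foo", set()))
cache.delete_many(stale_commands)

assert cache.get(("GET", "foo")) is None
assert cache.get(("GETRANGE", "foo", "0", "-1")) is None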
42 changes: 42 additions & 0 deletions tests/conftest.py
@@ -9,6 +9,7 @@
import pytest
import redis
from packaging.version import Version
from redis import Sentinel
from redis.backoff import NoBackoff
from redis.connection import Connection, parse_url
from redis.exceptions import RedisClusterException
@@ -105,6 +106,19 @@ def pytest_addoption(parser):
"--uvloop", action=BooleanOptionalAction, help="Run tests with uvloop"
)

parser.addoption(
"--sentinels",
action="store",
default="localhost:26379,localhost:26380,localhost:26381",
help="Comma-separated list of sentinel IPs and ports",
)
parser.addoption(
"--master-service",
action="store",
default="redis-py-test",
help="Name of the Redis master service that the sentinels are monitoring",
)


def _get_info(redis_url):
client = redis.Redis.from_url(redis_url)
@@ -352,6 +366,34 @@ def sslclient(request):
yield client


@pytest.fixture()
def sentinel_setup(local_cache, request):
sentinel_ips = request.config.getoption("--sentinels")
sentinel_endpoints = [
(ip.strip(), int(port.strip()))
for ip, port in (endpoint.split(":") for endpoint in sentinel_ips.split(","))
]
kwargs = request.param.get("kwargs", {}) if hasattr(request, "param") else {}
sentinel = Sentinel(
sentinel_endpoints,
socket_timeout=0.1,
client_cache=local_cache,
protocol=3,
**kwargs,
)
yield sentinel
for s in sentinel.sentinels:
s.close()


@pytest.fixture()
def master(request, sentinel_setup):
master_service = request.config.getoption("--master-service")
master = sentinel_setup.master_for(master_service)
yield master
master.close()


def _gen_cluster_mock_resp(r, response):
connection = Mock(spec=Connection)
connection.retry = Retry(NoBackoff(), 0)
29 changes: 29 additions & 0 deletions tests/test_asyncio/conftest.py
@@ -7,6 +7,7 @@
import redis.asyncio as redis
from packaging.version import Version
from redis._parsers import _AsyncHiredisParser, _AsyncRESP2Parser
from redis.asyncio import Sentinel
from redis.asyncio.client import Monitor
from redis.asyncio.connection import Connection, parse_url
from redis.asyncio.retry import Retry
@@ -136,6 +137,34 @@ async def decoded_r(create_redis):
return await create_redis(decode_responses=True)


@pytest_asyncio.fixture()
async def sentinel_setup(local_cache, request):
sentinel_ips = request.config.getoption("--sentinels")
sentinel_endpoints = [
(ip.strip(), int(port.strip()))
for ip, port in (endpoint.split(":") for endpoint in sentinel_ips.split(","))
]
kwargs = request.param.get("kwargs", {}) if hasattr(request, "param") else {}
sentinel = Sentinel(
sentinel_endpoints,
socket_timeout=0.1,
client_cache=local_cache,
protocol=3,
**kwargs,
)
yield sentinel
for s in sentinel.sentinels:
await s.close()


@pytest_asyncio.fixture()
async def master(request, sentinel_setup):
master_service = request.config.getoption("--master-service")
master = sentinel_setup.master_for(master_service)
yield master
await master.close()


def _gen_cluster_mock_resp(r, response):
connection = mock.AsyncMock(spec=Connection)
connection.retry = Retry(NoBackoff(), 0)
45 changes: 45 additions & 0 deletions tests/test_asyncio/test_cache.py
@@ -14,6 +14,11 @@ async def r(request, create_redis):
yield r, cache


@pytest_asyncio.fixture()
async def local_cache():
yield _LocalCache()


@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only")
class TestLocalCache:
@pytest.mark.onlynoncluster
@@ -228,3 +233,43 @@ async def test_cache_decode_response(self, r):
assert cache.get(("GET", "foo")) is None
# get key from redis
assert await r.get("foo") == "barbar"


@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only")
@pytest.mark.onlynoncluster
class TestSentinelLocalCache:

async def test_get_from_cache(self, local_cache, master):
await master.set("foo", "bar")
# get key from redis and save in local cache
assert await master.get("foo") == b"bar"
# get key from local cache
assert local_cache.get(("GET", "foo")) == b"bar"
# change key in redis (cause invalidation)
await master.set("foo", "barbar")
# send any command to redis (process invalidation in background)
await master.ping()
# the command is not in the local cache anymore
assert local_cache.get(("GET", "foo")) is None
# get key from redis
assert await master.get("foo") == b"barbar"

@pytest.mark.parametrize(
"sentinel_setup",
[{"kwargs": {"decode_responses": True}}],
indirect=True,
)
async def test_cache_decode_response(self, local_cache, sentinel_setup, master):
await master.set("foo", "bar")
# get key from redis and save in local cache
assert await master.get("foo") == "bar"
# get key from local cache
assert local_cache.get(("GET", "foo")) == "bar"
# change key in redis (cause invalidation)
await master.set("foo", "barbar")
# send any command to redis (process invalidation in background)
await master.ping()
# the command is not in the local cache anymore
assert local_cache.get(("GET", "foo")) is None
# get key from redis
assert await master.get("foo") == "barbar"
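
The commit message also mentions a test that uses a truly custom cache. As a hedged illustration only (the class name and internals below are hypothetical and not part of this diff), such a cache would provide the same surface as `_LocalCache` (set, get, delete_command, delete_many and flush), for example backed by the cachetools dependency added to dev_requirements.txt above:

from typing import Iterable, List, Union

from cachetools import LRUCache

from redis.typing import KeyT, ResponseT


class LRUCustomCache:
    """Hypothetical custom cache mirroring the _LocalCache surface."""

    def __init__(self, max_size: int = 100):
        self._data = LRUCache(maxsize=max_size)

    def set(
        self,
        command: Union[str, Iterable[str]],
        response: ResponseT,
        keys_in_command: List[KeyT],
    ):
        # A production implementation would also track keys_in_command so that
        # server-side invalidation messages can evict the right entries.
        self._data[command] = response

    def get(self, command: Union[str, Iterable[str]]) -> ResponseT:
        return self._data.get(command)

    def delete_command(self, command: Union[str, Iterable[str]]):
        self._data.pop(command, None)

    def delete_many(self, commands: List[Union[str, Iterable[str]]]):
        for command in commands:
            self.delete_command(command)

    def flush(self):
        self._data.clear()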
