Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Streamline client side caching API typing #3216

Merged
merged 12 commits into from May 8, 2024
1 change: 1 addition & 0 deletions dev_requirements.txt
@@ -1,5 +1,6 @@
click==8.0.4
black==24.3.0
cachetools
dvora-h marked this conversation as resolved.
Show resolved Hide resolved
flake8==5.0.4
flake8-isort==6.0.0
flynt~=0.69.0
Expand Down
3 changes: 2 additions & 1 deletion dockers/sentinel.conf
@@ -1,4 +1,5 @@
sentinel monitor redis-py-test 127.0.0.1 6379 2
sentinel resolve-hostnames yes
sentinel monitor redis-py-test redis 6379 2
sentinel down-after-milliseconds redis-py-test 5000
sentinel failover-timeout redis-py-test 60000
sentinel parallel-syncs redis-py-test 1
56 changes: 35 additions & 21 deletions redis/_cache.py
Expand Up @@ -4,13 +4,12 @@
from abc import ABC, abstractmethod
from collections import OrderedDict, defaultdict
from enum import Enum
from typing import List
from typing import Iterable, List, Union

from redis.typing import KeyT, ResponseT

DEFAULT_EVICTION_POLICY = "lru"


DEFAULT_BLACKLIST = [
"BF.CARD",
"BF.DEBUG",
Expand Down Expand Up @@ -71,7 +70,6 @@
"TTL",
]


DEFAULT_WHITELIST = [
"BITCOUNT",
"BITFIELD_RO",
Expand Down Expand Up @@ -180,7 +178,7 @@ def delete_command(self, command: str):
pass

@abstractmethod
def delete_many(self, commands):
def delete_commands(self, commands):
pass

@abstractmethod
Expand Down Expand Up @@ -215,7 +213,6 @@ def __init__(
max_size: int = 10000,
ttl: int = 0,
eviction_policy: EvictionPolicy = DEFAULT_EVICTION_POLICY,
**kwargs,
):
self.max_size = max_size
self.ttl = ttl
Expand All @@ -224,12 +221,17 @@ def __init__(
self.key_commands_map = defaultdict(set)
self.commands_ttl_list = []

def set(self, command: str, response: ResponseT, keys_in_command: List[KeyT]):
def set(
self,
command: Union[str, Iterable[str]],
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why it need to be also Iterable[str]?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because in all unit tests tuples are being sent for this parameter.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not in unit tests, that is for the get. In Connection we have this signature for example:

def _add_to_local_cache(
    self, command: Tuple[str], response: ResponseT, keys: List[KeysT]
):

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So maybe we need only Iterable[str]? I'm not sure...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a broader discussion, I think we can leave it with both for the beta release. In general I think we should not expose to the outside world the way we represent commands in the cache. We should only expose the concept of keys externally, and do the mapping to commands and back only internally, and then we have full freedom of representation.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I used Sequence instead of Iterable and Tuple. I think it is the most expressive one.

response: ResponseT,
keys_in_command: List[KeyT],
):
"""
Set a redis command and its response in the cache.

Args:
command (str): The redis command.
command (Union[str, Iterable[str]]): The redis command.
response (ResponseT): The response associated with the command.
keys_in_command (List[KeyT]): The list of keys used in the command.
"""
Expand All @@ -244,12 +246,12 @@ def set(self, command: str, response: ResponseT, keys_in_command: List[KeyT]):
self._update_key_commands_map(keys_in_command, command)
self.commands_ttl_list.append(command)

def get(self, command: str) -> ResponseT:
def get(self, command: Union[str, Iterable[str]]) -> ResponseT:
"""
Get the response for a redis command from the cache.

Args:
command (str): The redis command.
command (Union[str, Iterable[str]]): The redis command.

Returns:
ResponseT: The response associated with the command, or None if the command is not in the cache. # noqa
Expand All @@ -261,34 +263,42 @@ def get(self, command: str) -> ResponseT:
self._update_access(command)
return copy.deepcopy(self.cache[command]["response"])

def delete_command(self, command: str):
def delete_command(self, command: Union[str, Iterable[str]]):
"""
Delete a redis command and its metadata from the cache.

Args:
command (str): The redis command to be deleted.
command (Union[str, Iterable[str]]): The redis command to be deleted.
"""
if command in self.cache:
keys_in_command = self.cache[command].get("keys")
self._del_key_commands_map(keys_in_command, command)
self.commands_ttl_list.remove(command)
del self.cache[command]

def delete_many(self, commands):
pass
def delete_commands(self, commands: List[Union[str, Iterable[str]]]):
"""
Delete multiple commands and their metadata from the cache.

Args:
commands (List[Union[str, Iterable[str]]]): The list of commands to be
deleted.
"""
for command in commands:
self.delete_command(command)

def flush(self):
"""Clear the entire cache, removing all redis commands and metadata."""
self.cache.clear()
self.key_commands_map.clear()
self.commands_ttl_list = []

def _is_expired(self, command: str) -> bool:
def _is_expired(self, command: Union[str, Iterable[str]]) -> bool:
"""
Check if a redis command has expired based on its time-to-live.

Args:
command (str): The redis command.
command (Union[str, Iterable[str]]): The redis command.

Returns:
bool: True if the command has expired, False otherwise.
Expand All @@ -297,12 +307,12 @@ def _is_expired(self, command: str) -> bool:
return False
return time.monotonic() - self.cache[command]["ctime"] > self.ttl

def _update_access(self, command: str):
def _update_access(self, command: Union[str, Iterable[str]]):
"""
Update the access information for a redis command based on the eviction policy.

Args:
command (str): The redis command.
command (Union[str, Iterable[str]]): The redis command.
"""
if self.eviction_policy == EvictionPolicy.LRU.value:
self.cache.move_to_end(command)
Expand All @@ -329,24 +339,28 @@ def _evict(self):
random_command = random.choice(list(self.cache.keys()))
self.cache.pop(random_command)

def _update_key_commands_map(self, keys: List[KeyT], command: str):
def _update_key_commands_map(
self, keys: List[KeyT], command: Union[str, Iterable[str]]
):
"""
Update the key_commands_map with command that uses the keys.

Args:
keys (List[KeyT]): The list of keys used in the command.
command (str): The redis command.
command (Union[str, Iterable[str]]): The redis command.
"""
for key in keys:
self.key_commands_map[key].add(command)

def _del_key_commands_map(self, keys: List[KeyT], command: str):
def _del_key_commands_map(
self, keys: List[KeyT], command: Union[str, Iterable[str]]
):
"""
Remove a redis command from the key_commands_map.

Args:
keys (List[KeyT]): The list of keys used in the redis command.
command (str): The redis command.
command (Union[str, Iterable[str]]): The redis command.
"""
for key in keys:
self.key_commands_map[key].remove(command)
Expand Down
42 changes: 42 additions & 0 deletions tests/conftest.py
Expand Up @@ -9,6 +9,7 @@
import pytest
import redis
from packaging.version import Version
from redis import Sentinel
from redis.backoff import NoBackoff
from redis.connection import Connection, parse_url
from redis.exceptions import RedisClusterException
Expand Down Expand Up @@ -105,6 +106,19 @@ def pytest_addoption(parser):
"--uvloop", action=BooleanOptionalAction, help="Run tests with uvloop"
)

parser.addoption(
"--sentinels",
action="store",
default="localhost:26379,localhost:26380,localhost:26381",
help="Comma-separated list of sentinel IPs and ports",
)
parser.addoption(
"--master-service",
action="store",
default="redis-py-test",
help="Name of the Redis master service that the sentinels are monitoring",
)


def _get_info(redis_url):
client = redis.Redis.from_url(redis_url)
Expand Down Expand Up @@ -352,6 +366,34 @@ def sslclient(request):
yield client


@pytest.fixture()
def sentinel_setup(local_cache, request):
sentinel_ips = request.config.getoption("--sentinels")
sentinel_endpoints = [
(ip.strip(), int(port.strip()))
for ip, port in (endpoint.split(":") for endpoint in sentinel_ips.split(","))
]
kwargs = request.param.get("kwargs", {}) if hasattr(request, "param") else {}
sentinel = Sentinel(
sentinel_endpoints,
socket_timeout=0.1,
client_cache=local_cache,
protocol=3,
**kwargs,
)
yield sentinel
for s in sentinel.sentinels:
s.close()


@pytest.fixture()
def master(request, sentinel_setup):
master_service = request.config.getoption("--master-service")
master = sentinel_setup.master_for(master_service)
yield master
master.close()


def _gen_cluster_mock_resp(r, response):
connection = Mock(spec=Connection)
connection.retry = Retry(NoBackoff(), 0)
Expand Down
29 changes: 29 additions & 0 deletions tests/test_asyncio/conftest.py
Expand Up @@ -7,6 +7,7 @@
import redis.asyncio as redis
from packaging.version import Version
from redis._parsers import _AsyncHiredisParser, _AsyncRESP2Parser
from redis.asyncio import Sentinel
from redis.asyncio.client import Monitor
from redis.asyncio.connection import Connection, parse_url
from redis.asyncio.retry import Retry
Expand Down Expand Up @@ -136,6 +137,34 @@ async def decoded_r(create_redis):
return await create_redis(decode_responses=True)


@pytest_asyncio.fixture()
async def sentinel_setup(local_cache, request):
sentinel_ips = request.config.getoption("--sentinels")
sentinel_endpoints = [
(ip.strip(), int(port.strip()))
for ip, port in (endpoint.split(":") for endpoint in sentinel_ips.split(","))
]
kwargs = request.param.get("kwargs", {}) if hasattr(request, "param") else {}
sentinel = Sentinel(
sentinel_endpoints,
socket_timeout=0.1,
client_cache=local_cache,
protocol=3,
**kwargs,
)
yield sentinel
for s in sentinel.sentinels:
await s.close()


@pytest_asyncio.fixture()
async def master(request, sentinel_setup):
master_service = request.config.getoption("--master-service")
master = sentinel_setup.master_for(master_service)
yield master
await master.close()


def _gen_cluster_mock_resp(r, response):
connection = mock.AsyncMock(spec=Connection)
connection.retry = Retry(NoBackoff(), 0)
Expand Down
45 changes: 45 additions & 0 deletions tests/test_asyncio/test_cache.py
Expand Up @@ -14,6 +14,11 @@ async def r(request, create_redis):
yield r, cache


@pytest_asyncio.fixture()
async def local_cache():
yield _LocalCache()


@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only")
class TestLocalCache:
@pytest.mark.onlynoncluster
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why? I think it will fail on cluster

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I moved the decorator to class level.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved back.

Expand Down Expand Up @@ -228,3 +233,43 @@ async def test_cache_decode_response(self, r):
assert cache.get(("GET", "foo")) is None
# get key from redis
assert await r.get("foo") == "barbar"


@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only")
@pytest.mark.onlynoncluster
class TestSentinelLocalCache:

async def test_get_from_cache(self, local_cache, master):
await master.set("foo", "bar")
# get key from redis and save in local cache
assert await master.get("foo") == b"bar"
# get key from local cache
assert local_cache.get(("GET", "foo")) == b"bar"
# change key in redis (cause invalidation)
await master.set("foo", "barbar")
# send any command to redis (process invalidation in background)
await master.ping()
vladvildanov marked this conversation as resolved.
Show resolved Hide resolved
# the command is not in the local cache anymore
assert local_cache.get(("GET", "foo")) is None
# get key from redis
assert await master.get("foo") == b"barbar"

@pytest.mark.parametrize(
"sentinel_setup",
[{"kwargs": {"decode_responses": True}}],
indirect=True,
)
async def test_cache_decode_response(self, local_cache, sentinel_setup, master):
await master.set("foo", "bar")
# get key from redis and save in local cache
assert await master.get("foo") == "bar"
# get key from local cache
assert local_cache.get(("GET", "foo")) == "bar"
# change key in redis (cause invalidation)
await master.set("foo", "barbar")
# send any command to redis (process invalidation in background)
await master.ping()
# the command is not in the local cache anymore
assert local_cache.get(("GET", "foo")) is None
# get key from redis
assert await master.get("foo") == "barbar"