Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Streamline client side caching API typing #3216

Merged
merged 12 commits into from May 8, 2024
1 change: 1 addition & 0 deletions dev_requirements.txt
@@ -1,5 +1,6 @@
click==8.0.4
black==24.3.0
cachetools
dvora-h marked this conversation as resolved.
Show resolved Hide resolved
flake8==5.0.4
flake8-isort==6.0.0
flynt~=0.69.0
Expand Down
3 changes: 2 additions & 1 deletion dockers/sentinel.conf
@@ -1,4 +1,5 @@
sentinel monitor redis-py-test 127.0.0.1 6379 2
sentinel resolve-hostnames yes
sentinel monitor redis-py-test redis 6379 2
sentinel down-after-milliseconds redis-py-test 5000
sentinel failover-timeout redis-py-test 60000
sentinel parallel-syncs redis-py-test 1
85 changes: 50 additions & 35 deletions redis/_cache.py
Expand Up @@ -4,14 +4,20 @@
from abc import ABC, abstractmethod
from collections import OrderedDict, defaultdict
from enum import Enum
from typing import List
from typing import Iterable, List, Union

from redis.typing import KeyT, ResponseT

DEFAULT_EVICTION_POLICY = "lru"

class EvictionPolicy(Enum):
    """Strategies for choosing which cached entry to evict when the cache is full."""

    LRU = "lru"  # least-recently-used: evict the entry accessed longest ago
    LFU = "lfu"  # least-frequently-used: evict the entry with the lowest access count
    RANDOM = "random"  # evict a uniformly random entry

DEFAULT_BLACKLIST = [

DEFAULT_EVICTION_POLICY = EvictionPolicy.LRU

DEFAULT_DENY_LIST = [
"BF.CARD",
"BF.DEBUG",
"BF.EXISTS",
Expand Down Expand Up @@ -71,8 +77,7 @@
"TTL",
]


DEFAULT_WHITELIST = [
DEFAULT_ALLOW_LIST = [
"BITCOUNT",
"BITFIELD_RO",
"BITPOS",
Expand Down Expand Up @@ -155,12 +160,6 @@
_ACCESS_COUNT = "access_count"


class EvictionPolicy(Enum):
LRU = "lru"
LFU = "lfu"
RANDOM = "random"


class AbstractCache(ABC):
"""
An abstract base class for client caching implementations.
Expand All @@ -180,7 +179,7 @@ def delete_command(self, command: str):
pass

@abstractmethod
def delete_many(self, commands):
def delete_commands(self, commands):
pass

@abstractmethod
Expand Down Expand Up @@ -215,7 +214,6 @@ def __init__(
max_size: int = 10000,
ttl: int = 0,
eviction_policy: EvictionPolicy = DEFAULT_EVICTION_POLICY,
**kwargs,
):
self.max_size = max_size
self.ttl = ttl
Expand All @@ -224,12 +222,17 @@ def __init__(
self.key_commands_map = defaultdict(set)
self.commands_ttl_list = []

def set(self, command: str, response: ResponseT, keys_in_command: List[KeyT]):
def set(
self,
command: Union[str, Iterable[str]],
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why does it also need to be Iterable[str]?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because in all unit tests tuples are being sent for this parameter.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not in unit tests, that is for the get. In Connection we have this signature for example:

def _add_to_local_cache(
    self, command: Tuple[str], response: ResponseT, keys: List[KeysT]
):

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So maybe we need only Iterable[str]? I'm not sure...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a broader discussion, I think we can leave it with both for the beta release. In general I think we should not expose to the outside world the way we represent commands in the cache. We should only expose the concept of keys externally, and do the mapping to commands and back only internally, and then we have full freedom of representation.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I used Sequence instead of Iterable and Tuple. I think it is the most expressive one.

response: ResponseT,
keys_in_command: List[KeyT],
):
"""
Set a redis command and its response in the cache.

Args:
command (str): The redis command.
command (Union[str, Iterable[str]]): The redis command.
response (ResponseT): The response associated with the command.
keys_in_command (List[KeyT]): The list of keys used in the command.
"""
Expand All @@ -244,12 +247,12 @@ def set(self, command: str, response: ResponseT, keys_in_command: List[KeyT]):
self._update_key_commands_map(keys_in_command, command)
self.commands_ttl_list.append(command)

def get(self, command: str) -> ResponseT:
def get(self, command: Union[str, Iterable[str]]) -> ResponseT:
"""
Get the response for a redis command from the cache.

Args:
command (str): The redis command.
command (Union[str, Iterable[str]]): The redis command.

Returns:
ResponseT: The response associated with the command, or None if the command is not in the cache. # noqa
Expand All @@ -261,34 +264,42 @@ def get(self, command: str) -> ResponseT:
self._update_access(command)
return copy.deepcopy(self.cache[command]["response"])

def delete_command(self, command: str):
def delete_command(self, command: Union[str, Iterable[str]]):
"""
Delete a redis command and its metadata from the cache.

Args:
command (str): The redis command to be deleted.
command (Union[str, Iterable[str]]): The redis command to be deleted.
"""
if command in self.cache:
keys_in_command = self.cache[command].get("keys")
self._del_key_commands_map(keys_in_command, command)
self.commands_ttl_list.remove(command)
del self.cache[command]

def delete_many(self, commands):
pass
def delete_commands(self, commands: List[Union[str, Iterable[str]]]):
    """
    Remove every command in *commands*, along with its metadata, from the cache.

    Args:
        commands (List[Union[str, Iterable[str]]]): The commands to evict.
    """
    for entry in commands:
        self.delete_command(entry)

def flush(self):
    """Wipe the whole cache: discard every stored command, response, and metadata."""
    # Rebind (rather than clear) the TTL list, matching the constructor's fresh-list state.
    self.commands_ttl_list = []
    self.key_commands_map.clear()
    self.cache.clear()

def _is_expired(self, command: str) -> bool:
def _is_expired(self, command: Union[str, Iterable[str]]) -> bool:
"""
Check if a redis command has expired based on its time-to-live.

Args:
command (str): The redis command.
command (Union[str, Iterable[str]]): The redis command.

Returns:
bool: True if the command has expired, False otherwise.
Expand All @@ -297,56 +308,60 @@ def _is_expired(self, command: str) -> bool:
return False
return time.monotonic() - self.cache[command]["ctime"] > self.ttl

def _update_access(self, command: str):
def _update_access(self, command: Union[str, Iterable[str]]):
"""
Update the access information for a redis command based on the eviction policy.

Args:
command (str): The redis command.
command (Union[str, Iterable[str]]): The redis command.
"""
if self.eviction_policy == EvictionPolicy.LRU.value:
if self.eviction_policy == EvictionPolicy.LRU:
self.cache.move_to_end(command)
elif self.eviction_policy == EvictionPolicy.LFU.value:
elif self.eviction_policy == EvictionPolicy.LFU:
self.cache[command]["access_count"] = (
self.cache.get(command, {}).get("access_count", 0) + 1
)
self.cache.move_to_end(command)
elif self.eviction_policy == EvictionPolicy.RANDOM.value:
elif self.eviction_policy == EvictionPolicy.RANDOM:
pass # Random eviction doesn't require updates

def _evict(self):
"""Evict a redis command from the cache based on the eviction policy."""
if self._is_expired(self.commands_ttl_list[0]):
self.delete_command(self.commands_ttl_list[0])
elif self.eviction_policy == EvictionPolicy.LRU.value:
elif self.eviction_policy == EvictionPolicy.LRU:
self.cache.popitem(last=False)
elif self.eviction_policy == EvictionPolicy.LFU.value:
elif self.eviction_policy == EvictionPolicy.LFU:
min_access_command = min(
self.cache, key=lambda k: self.cache[k].get("access_count", 0)
)
self.cache.pop(min_access_command)
elif self.eviction_policy == EvictionPolicy.RANDOM.value:
elif self.eviction_policy == EvictionPolicy.RANDOM:
random_command = random.choice(list(self.cache.keys()))
self.cache.pop(random_command)

def _update_key_commands_map(self, keys: List[KeyT], command: str):
def _update_key_commands_map(
    self, keys: List[KeyT], command: Union[str, Iterable[str]]
):
    """
    Update the key_commands_map with command that uses the keys.

    Args:
        keys (List[KeyT]): The list of keys used in the command.
        command (Union[str, Iterable[str]]): The redis command.
    """
    # NOTE: command is stored in a set, so it must be hashable (str or tuple).
    for key in keys:
        self.key_commands_map[key].add(command)

def _del_key_commands_map(self, keys: List[KeyT], command: str):
def _del_key_commands_map(
    self, keys: List[KeyT], command: Union[str, Iterable[str]]
):
    """
    Remove a redis command from the key_commands_map.

    Args:
        keys (List[KeyT]): The list of keys used in the redis command.
        command (Union[str, Iterable[str]]): The redis command.
    """
    # NOTE(review): set.remove raises KeyError if command is absent for a key —
    # presumably callers only pass commands previously registered; verify.
    for key in keys:
        self.key_commands_map[key].remove(command)
Expand Down
15 changes: 8 additions & 7 deletions redis/asyncio/client.py
Expand Up @@ -27,9 +27,9 @@
)

from redis._cache import (
DEFAULT_BLACKLIST,
DEFAULT_ALLOW_LIST,
DEFAULT_DENY_LIST,
DEFAULT_EVICTION_POLICY,
DEFAULT_WHITELIST,
AbstractCache,
)
from redis._parsers.helpers import (
Expand Down Expand Up @@ -243,8 +243,8 @@ def __init__(
cache_max_size: int = 100,
cache_ttl: int = 0,
cache_policy: str = DEFAULT_EVICTION_POLICY,
cache_blacklist: List[str] = DEFAULT_BLACKLIST,
cache_whitelist: List[str] = DEFAULT_WHITELIST,
cache_deny_list: List[str] = DEFAULT_DENY_LIST,
cache_allow_list: List[str] = DEFAULT_ALLOW_LIST,
):
"""
Initialize a new Redis client.
Expand Down Expand Up @@ -299,8 +299,8 @@ def __init__(
"cache_max_size": cache_max_size,
"cache_ttl": cache_ttl,
"cache_policy": cache_policy,
"cache_blacklist": cache_blacklist,
"cache_whitelist": cache_whitelist,
"cache_deny_list": cache_deny_list,
"cache_allow_list": cache_allow_list,
}
# based on input, setup appropriate connection args
if unix_socket_path is not None:
Expand Down Expand Up @@ -640,7 +640,8 @@ async def execute_command(self, *args, **options):
),
lambda error: self._disconnect_raise(conn, error),
)
conn._add_to_local_cache(args, response, keys)
if keys:
conn._add_to_local_cache(args, response, keys)
return response
finally:
if self.single_connection_client:
Expand Down
15 changes: 8 additions & 7 deletions redis/asyncio/cluster.py
Expand Up @@ -20,9 +20,9 @@
)

from redis._cache import (
DEFAULT_BLACKLIST,
DEFAULT_ALLOW_LIST,
DEFAULT_DENY_LIST,
DEFAULT_EVICTION_POLICY,
DEFAULT_WHITELIST,
AbstractCache,
)
from redis._parsers import AsyncCommandsParser, Encoder
Expand Down Expand Up @@ -280,8 +280,8 @@ def __init__(
cache_max_size: int = 100,
cache_ttl: int = 0,
cache_policy: str = DEFAULT_EVICTION_POLICY,
cache_blacklist: List[str] = DEFAULT_BLACKLIST,
cache_whitelist: List[str] = DEFAULT_WHITELIST,
cache_deny_list: List[str] = DEFAULT_DENY_LIST,
cache_allow_list: List[str] = DEFAULT_ALLOW_LIST,
) -> None:
if db:
raise RedisClusterException(
Expand Down Expand Up @@ -331,8 +331,8 @@ def __init__(
"cache_max_size": cache_max_size,
"cache_ttl": cache_ttl,
"cache_policy": cache_policy,
"cache_blacklist": cache_blacklist,
"cache_whitelist": cache_whitelist,
"cache_deny_list": cache_deny_list,
"cache_allow_list": cache_allow_list,
}

if ssl:
Expand Down Expand Up @@ -1075,7 +1075,8 @@ async def execute_command(self, *args: Any, **kwargs: Any) -> Any:
# Read response
try:
response = await self.parse_response(connection, args[0], **kwargs)
connection._add_to_local_cache(args, response, keys)
if keys:
connection._add_to_local_cache(args, response, keys)
return response
finally:
# Release connection
Expand Down
26 changes: 13 additions & 13 deletions redis/asyncio/connection.py
Expand Up @@ -51,9 +51,9 @@
from redis.utils import HIREDIS_AVAILABLE, get_lib_version, str_if_bytes

from .._cache import (
DEFAULT_BLACKLIST,
DEFAULT_ALLOW_LIST,
DEFAULT_DENY_LIST,
DEFAULT_EVICTION_POLICY,
DEFAULT_WHITELIST,
AbstractCache,
_LocalCache,
)
Expand Down Expand Up @@ -120,8 +120,8 @@ class AbstractConnection:
"ssl_context",
"protocol",
"client_cache",
"cache_blacklist",
"cache_whitelist",
"cache_deny_list",
"cache_allow_list",
"_reader",
"_writer",
"_parser",
Expand Down Expand Up @@ -161,8 +161,8 @@ def __init__(
cache_max_size: int = 10000,
cache_ttl: int = 0,
cache_policy: str = DEFAULT_EVICTION_POLICY,
cache_blacklist: List[str] = DEFAULT_BLACKLIST,
cache_whitelist: List[str] = DEFAULT_WHITELIST,
cache_deny_list: List[str] = DEFAULT_DENY_LIST,
cache_allow_list: List[str] = DEFAULT_ALLOW_LIST,
):
if (username or password) and credential_provider is not None:
raise DataError(
Expand Down Expand Up @@ -230,8 +230,8 @@ def __init__(
raise RedisError(
"client caching is only supported with protocol version 3 or higher"
)
self.cache_blacklist = cache_blacklist
self.cache_whitelist = cache_whitelist
self.cache_deny_list = cache_deny_list
self.cache_allow_list = cache_allow_list

def __del__(self, _warnings: Any = warnings):
# For some reason, the individual streams don't get properly garbage
Expand Down Expand Up @@ -696,7 +696,7 @@ def _cache_invalidation_process(
and the second string is the list of keys to invalidate.
(if the list of keys is None, then all keys are invalidated)
"""
if data[1] is not None:
if data[1] is None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great catch! thanks

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Surfaced due to the new tests, basically luck.

self.client_cache.flush()
else:
for key in data[1]:
Expand All @@ -708,8 +708,8 @@ async def _get_from_local_cache(self, command: str):
"""
if (
self.client_cache is None
or command[0] in self.cache_blacklist
or command[0] not in self.cache_whitelist
or command[0] in self.cache_deny_list
or command[0] not in self.cache_allow_list
):
return None
while not self._socket_is_empty():
Expand All @@ -725,8 +725,8 @@ def _add_to_local_cache(
"""
if (
self.client_cache is not None
and (self.cache_blacklist == [] or command[0] not in self.cache_blacklist)
and (self.cache_whitelist == [] or command[0] in self.cache_whitelist)
and (self.cache_deny_list == [] or command[0] not in self.cache_deny_list)
and (self.cache_allow_list == [] or command[0] in self.cache_allow_list)
):
self.client_cache.set(command, response, keys)

Expand Down