Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Streamline client side caching API typing #3216

Merged
merged 12 commits into from May 8, 2024
1 change: 1 addition & 0 deletions dev_requirements.txt
@@ -1,5 +1,6 @@
click==8.0.4
black==24.3.0
cachetools
dvora-h marked this conversation as resolved.
Show resolved Hide resolved
flake8==5.0.4
flake8-isort==6.0.0
flynt~=0.69.0
Expand Down
3 changes: 2 additions & 1 deletion dockers/sentinel.conf
@@ -1,4 +1,5 @@
sentinel monitor redis-py-test 127.0.0.1 6379 2
sentinel resolve-hostnames yes
sentinel monitor redis-py-test redis 6379 2
sentinel down-after-milliseconds redis-py-test 5000
sentinel failover-timeout redis-py-test 60000
sentinel parallel-syncs redis-py-test 1
96 changes: 58 additions & 38 deletions redis/_cache.py
Expand Up @@ -4,14 +4,20 @@
from abc import ABC, abstractmethod
from collections import OrderedDict, defaultdict
from enum import Enum
from typing import List
from typing import List, Sequence, Union

from redis.typing import KeyT, ResponseT

DEFAULT_EVICTION_POLICY = "lru"

class EvictionPolicy(Enum):
LRU = "lru"
LFU = "lfu"
RANDOM = "random"

DEFAULT_BLACKLIST = [

DEFAULT_EVICTION_POLICY = EvictionPolicy.LRU

DEFAULT_DENY_LIST = [
"BF.CARD",
"BF.DEBUG",
"BF.EXISTS",
Expand Down Expand Up @@ -71,8 +77,7 @@
"TTL",
]


DEFAULT_WHITELIST = [
DEFAULT_ALLOW_LIST = [
"BITCOUNT",
"BITFIELD_RO",
"BITPOS",
Expand Down Expand Up @@ -155,32 +160,31 @@
_ACCESS_COUNT = "access_count"


class EvictionPolicy(Enum):
LRU = "lru"
LFU = "lfu"
RANDOM = "random"


class AbstractCache(ABC):
"""
An abstract base class for client caching implementations.
If you want to implement your own cache you must support these methods.
"""

@abstractmethod
def set(self, command: str, response: ResponseT, keys_in_command: List[KeyT]):
def set(
self,
command: Union[str, Sequence[str]],
response: ResponseT,
keys_in_command: List[KeyT],
):
pass

@abstractmethod
def get(self, command: str) -> ResponseT:
def get(self, command: Union[str, Sequence[str]]) -> ResponseT:
pass

@abstractmethod
def delete_command(self, command: str):
def delete_command(self, command: Union[str, Sequence[str]]):
pass

@abstractmethod
def delete_many(self, commands):
def delete_commands(self, commands: List[Union[str, Sequence[str]]]):
pass

@abstractmethod
Expand Down Expand Up @@ -215,7 +219,6 @@ def __init__(
max_size: int = 10000,
ttl: int = 0,
eviction_policy: EvictionPolicy = DEFAULT_EVICTION_POLICY,
**kwargs,
):
self.max_size = max_size
self.ttl = ttl
Expand All @@ -224,12 +227,17 @@ def __init__(
self.key_commands_map = defaultdict(set)
self.commands_ttl_list = []

def set(self, command: str, response: ResponseT, keys_in_command: List[KeyT]):
def set(
self,
command: Union[str, Sequence[str]],
response: ResponseT,
keys_in_command: List[KeyT],
):
"""
Set a redis command and its response in the cache.

Args:
command (str): The redis command.
command (Union[str, Sequence[str]]): The redis command.
response (ResponseT): The response associated with the command.
keys_in_command (List[KeyT]): The list of keys used in the command.
"""
Expand All @@ -244,12 +252,12 @@ def set(self, command: str, response: ResponseT, keys_in_command: List[KeyT]):
self._update_key_commands_map(keys_in_command, command)
self.commands_ttl_list.append(command)

def get(self, command: str) -> ResponseT:
def get(self, command: Union[str, Sequence[str]]) -> ResponseT:
"""
Get the response for a redis command from the cache.

Args:
command (str): The redis command.
command (Union[str, Sequence[str]]): The redis command.

Returns:
ResponseT: The response associated with the command, or None if the command is not in the cache. # noqa
Expand All @@ -261,34 +269,42 @@ def get(self, command: str) -> ResponseT:
self._update_access(command)
return copy.deepcopy(self.cache[command]["response"])

def delete_command(self, command: str):
def delete_command(self, command: Union[str, Sequence[str]]):
"""
Delete a redis command and its metadata from the cache.

Args:
command (str): The redis command to be deleted.
command (Union[str, Sequence[str]]): The redis command to be deleted.
"""
if command in self.cache:
keys_in_command = self.cache[command].get("keys")
self._del_key_commands_map(keys_in_command, command)
self.commands_ttl_list.remove(command)
del self.cache[command]

def delete_many(self, commands):
pass
def delete_commands(self, commands: List[Union[str, Sequence[str]]]):
"""
Delete multiple commands and their metadata from the cache.

Args:
commands (List[Union[str, Sequence[str]]]): The list of commands to be
deleted.
"""
for command in commands:
self.delete_command(command)

def flush(self):
"""Clear the entire cache, removing all redis commands and metadata."""
self.cache.clear()
self.key_commands_map.clear()
self.commands_ttl_list = []

def _is_expired(self, command: str) -> bool:
def _is_expired(self, command: Union[str, Sequence[str]]) -> bool:
"""
Check if a redis command has expired based on its time-to-live.

Args:
command (str): The redis command.
command (Union[str, Sequence[str]]): The redis command.

Returns:
bool: True if the command has expired, False otherwise.
Expand All @@ -297,56 +313,60 @@ def _is_expired(self, command: str) -> bool:
return False
return time.monotonic() - self.cache[command]["ctime"] > self.ttl

def _update_access(self, command: str):
def _update_access(self, command: Union[str, Sequence[str]]):
"""
Update the access information for a redis command based on the eviction policy.

Args:
command (str): The redis command.
command (Union[str, Sequence[str]]): The redis command.
"""
if self.eviction_policy == EvictionPolicy.LRU.value:
if self.eviction_policy == EvictionPolicy.LRU:
self.cache.move_to_end(command)
elif self.eviction_policy == EvictionPolicy.LFU.value:
elif self.eviction_policy == EvictionPolicy.LFU:
self.cache[command]["access_count"] = (
self.cache.get(command, {}).get("access_count", 0) + 1
)
self.cache.move_to_end(command)
elif self.eviction_policy == EvictionPolicy.RANDOM.value:
elif self.eviction_policy == EvictionPolicy.RANDOM:
pass # Random eviction doesn't require updates

def _evict(self):
"""Evict a redis command from the cache based on the eviction policy."""
if self._is_expired(self.commands_ttl_list[0]):
self.delete_command(self.commands_ttl_list[0])
elif self.eviction_policy == EvictionPolicy.LRU.value:
elif self.eviction_policy == EvictionPolicy.LRU:
self.cache.popitem(last=False)
elif self.eviction_policy == EvictionPolicy.LFU.value:
elif self.eviction_policy == EvictionPolicy.LFU:
min_access_command = min(
self.cache, key=lambda k: self.cache[k].get("access_count", 0)
)
self.cache.pop(min_access_command)
elif self.eviction_policy == EvictionPolicy.RANDOM.value:
elif self.eviction_policy == EvictionPolicy.RANDOM:
random_command = random.choice(list(self.cache.keys()))
self.cache.pop(random_command)

def _update_key_commands_map(self, keys: List[KeyT], command: str):
def _update_key_commands_map(
self, keys: List[KeyT], command: Union[str, Sequence[str]]
):
"""
Update the key_commands_map with command that uses the keys.

Args:
keys (List[KeyT]): The list of keys used in the command.
command (str): The redis command.
command (Union[str, Sequence[str]]): The redis command.
"""
for key in keys:
self.key_commands_map[key].add(command)

def _del_key_commands_map(self, keys: List[KeyT], command: str):
def _del_key_commands_map(
self, keys: List[KeyT], command: Union[str, Sequence[str]]
):
"""
Remove a redis command from the key_commands_map.

Args:
keys (List[KeyT]): The list of keys used in the redis command.
command (str): The redis command.
command (Union[str, Sequence[str]]): The redis command.
"""
for key in keys:
self.key_commands_map[key].remove(command)
Expand Down
48 changes: 20 additions & 28 deletions redis/asyncio/client.py
Expand Up @@ -27,9 +27,9 @@
)

from redis._cache import (
DEFAULT_BLACKLIST,
DEFAULT_ALLOW_LIST,
DEFAULT_DENY_LIST,
DEFAULT_EVICTION_POLICY,
DEFAULT_WHITELIST,
AbstractCache,
)
from redis._parsers.helpers import (
Expand Down Expand Up @@ -243,8 +243,8 @@ def __init__(
cache_max_size: int = 100,
cache_ttl: int = 0,
cache_policy: str = DEFAULT_EVICTION_POLICY,
cache_blacklist: List[str] = DEFAULT_BLACKLIST,
cache_whitelist: List[str] = DEFAULT_WHITELIST,
cache_deny_list: List[str] = DEFAULT_DENY_LIST,
cache_allow_list: List[str] = DEFAULT_ALLOW_LIST,
):
"""
Initialize a new Redis client.
Expand Down Expand Up @@ -299,8 +299,8 @@ def __init__(
"cache_max_size": cache_max_size,
"cache_ttl": cache_ttl,
"cache_policy": cache_policy,
"cache_blacklist": cache_blacklist,
"cache_whitelist": cache_whitelist,
"cache_deny_list": cache_deny_list,
"cache_allow_list": cache_allow_list,
}
# based on input, setup appropriate connection args
if unix_socket_path is not None:
Expand Down Expand Up @@ -640,7 +640,8 @@ async def execute_command(self, *args, **options):
),
lambda error: self._disconnect_raise(conn, error),
)
conn._add_to_local_cache(args, response, keys)
if keys:
conn._add_to_local_cache(args, response, keys)
return response
finally:
if self.single_connection_client:
Expand Down Expand Up @@ -675,31 +676,22 @@ async def parse_response(
return response

def flush_cache(self):
try:
if self.connection:
self.connection.client_cache.flush()
else:
self.connection_pool.flush_cache()
except AttributeError:
pass
if self.connection:
self.connection.flush_cache()
else:
self.connection_pool.flush_cache()

def delete_command_from_cache(self, command):
try:
if self.connection:
self.connection.client_cache.delete_command(command)
else:
self.connection_pool.delete_command_from_cache(command)
except AttributeError:
pass
if self.connection:
self.connection.delete_command_from_cache(command)
else:
self.connection_pool.delete_command_from_cache(command)

def invalidate_key_from_cache(self, key):
try:
if self.connection:
self.connection.client_cache.invalidate_key(key)
else:
self.connection_pool.invalidate_key_from_cache(key)
except AttributeError:
pass
if self.connection:
self.connection.invalidate_key_from_cache(key)
else:
self.connection_pool.invalidate_key_from_cache(key)


StrictRedis = Redis
Expand Down