Skip to content

Commit

Permalink
Streamline client side caching API typing (#3216)
Browse files Browse the repository at this point in the history
* Streamline client side caching API typing

Streamline the typing of the client side caching API. Some of the
methods are defining commands of type `str`, while in reality tuples
are being sent for those parameters.

Add client side cache tests for Sentinels. In order to make this work,
fix the sentinel configuration in the docker-compose stack.

Add a test for client side caching with a truly custom cache, not just
injecting our internal cache structure as custom.

Add a test for client side caching where two different types of commands
use the same key, to make sure they invalidate each others cached data.

* Fixes after running tests against RE

* More test cases

* Fix async tests

* Tests for raw commands

* Change terminology for allow/deny lists

* Add test for single connection

* Make sure flushing the cache works everywhere

* Reenable some tests for cluster too

* Align cache typings at abstract level

* Use Sequence instead of Iterable for types

* Remove some exceptions in favor of ifs

---------

Co-authored-by: Gabriel Erzse <gabriel.erzse@redis.com>
  • Loading branch information
gerzse and Gabriel Erzse committed May 8, 2024
1 parent 64f291f commit 3bd311c
Show file tree
Hide file tree
Showing 13 changed files with 861 additions and 186 deletions.
1 change: 1 addition & 0 deletions dev_requirements.txt
@@ -1,5 +1,6 @@
click==8.0.4
black==24.3.0
cachetools
flake8==5.0.4
flake8-isort==6.0.0
flynt~=0.69.0
Expand Down
3 changes: 2 additions & 1 deletion dockers/sentinel.conf
@@ -1,4 +1,5 @@
sentinel monitor redis-py-test 127.0.0.1 6379 2
sentinel resolve-hostnames yes
sentinel monitor redis-py-test redis 6379 2
sentinel down-after-milliseconds redis-py-test 5000
sentinel failover-timeout redis-py-test 60000
sentinel parallel-syncs redis-py-test 1
96 changes: 58 additions & 38 deletions redis/_cache.py
Expand Up @@ -4,14 +4,20 @@
from abc import ABC, abstractmethod
from collections import OrderedDict, defaultdict
from enum import Enum
from typing import List
from typing import List, Sequence, Union

from redis.typing import KeyT, ResponseT

DEFAULT_EVICTION_POLICY = "lru"

class EvictionPolicy(Enum):
LRU = "lru"
LFU = "lfu"
RANDOM = "random"

DEFAULT_BLACKLIST = [

DEFAULT_EVICTION_POLICY = EvictionPolicy.LRU

DEFAULT_DENY_LIST = [
"BF.CARD",
"BF.DEBUG",
"BF.EXISTS",
Expand Down Expand Up @@ -71,8 +77,7 @@
"TTL",
]


DEFAULT_WHITELIST = [
DEFAULT_ALLOW_LIST = [
"BITCOUNT",
"BITFIELD_RO",
"BITPOS",
Expand Down Expand Up @@ -155,32 +160,31 @@
_ACCESS_COUNT = "access_count"


class EvictionPolicy(Enum):
LRU = "lru"
LFU = "lfu"
RANDOM = "random"


class AbstractCache(ABC):
"""
An abstract base class for client caching implementations.
If you want to implement your own cache you must support these methods.
"""

@abstractmethod
def set(self, command: str, response: ResponseT, keys_in_command: List[KeyT]):
def set(
self,
command: Union[str, Sequence[str]],
response: ResponseT,
keys_in_command: List[KeyT],
):
pass

@abstractmethod
def get(self, command: str) -> ResponseT:
def get(self, command: Union[str, Sequence[str]]) -> ResponseT:
pass

@abstractmethod
def delete_command(self, command: str):
def delete_command(self, command: Union[str, Sequence[str]]):
pass

@abstractmethod
def delete_many(self, commands):
def delete_commands(self, commands: List[Union[str, Sequence[str]]]):
pass

@abstractmethod
Expand Down Expand Up @@ -215,7 +219,6 @@ def __init__(
max_size: int = 10000,
ttl: int = 0,
eviction_policy: EvictionPolicy = DEFAULT_EVICTION_POLICY,
**kwargs,
):
self.max_size = max_size
self.ttl = ttl
Expand All @@ -224,12 +227,17 @@ def __init__(
self.key_commands_map = defaultdict(set)
self.commands_ttl_list = []

def set(self, command: str, response: ResponseT, keys_in_command: List[KeyT]):
def set(
self,
command: Union[str, Sequence[str]],
response: ResponseT,
keys_in_command: List[KeyT],
):
"""
Set a redis command and its response in the cache.
Args:
command (str): The redis command.
command (Union[str, Sequence[str]]): The redis command.
response (ResponseT): The response associated with the command.
keys_in_command (List[KeyT]): The list of keys used in the command.
"""
Expand All @@ -244,12 +252,12 @@ def set(self, command: str, response: ResponseT, keys_in_command: List[KeyT]):
self._update_key_commands_map(keys_in_command, command)
self.commands_ttl_list.append(command)

def get(self, command: str) -> ResponseT:
def get(self, command: Union[str, Sequence[str]]) -> ResponseT:
"""
Get the response for a redis command from the cache.
Args:
command (str): The redis command.
command (Union[str, Sequence[str]]): The redis command.
Returns:
ResponseT: The response associated with the command, or None if the command is not in the cache. # noqa
Expand All @@ -261,34 +269,42 @@ def get(self, command: str) -> ResponseT:
self._update_access(command)
return copy.deepcopy(self.cache[command]["response"])

def delete_command(self, command: str):
def delete_command(self, command: Union[str, Sequence[str]]):
"""
Delete a redis command and its metadata from the cache.
Args:
command (str): The redis command to be deleted.
command (Union[str, Sequence[str]]): The redis command to be deleted.
"""
if command in self.cache:
keys_in_command = self.cache[command].get("keys")
self._del_key_commands_map(keys_in_command, command)
self.commands_ttl_list.remove(command)
del self.cache[command]

def delete_many(self, commands):
pass
def delete_commands(self, commands: List[Union[str, Sequence[str]]]):
"""
Delete multiple commands and their metadata from the cache.
Args:
commands (List[Union[str, Sequence[str]]]): The list of commands to be
deleted.
"""
for command in commands:
self.delete_command(command)

def flush(self):
"""Clear the entire cache, removing all redis commands and metadata."""
self.cache.clear()
self.key_commands_map.clear()
self.commands_ttl_list = []

def _is_expired(self, command: str) -> bool:
def _is_expired(self, command: Union[str, Sequence[str]]) -> bool:
"""
Check if a redis command has expired based on its time-to-live.
Args:
command (str): The redis command.
command (Union[str, Sequence[str]]): The redis command.
Returns:
bool: True if the command has expired, False otherwise.
Expand All @@ -297,56 +313,60 @@ def _is_expired(self, command: str) -> bool:
return False
return time.monotonic() - self.cache[command]["ctime"] > self.ttl

def _update_access(self, command: str):
def _update_access(self, command: Union[str, Sequence[str]]):
"""
Update the access information for a redis command based on the eviction policy.
Args:
command (str): The redis command.
command (Union[str, Sequence[str]]): The redis command.
"""
if self.eviction_policy == EvictionPolicy.LRU.value:
if self.eviction_policy == EvictionPolicy.LRU:
self.cache.move_to_end(command)
elif self.eviction_policy == EvictionPolicy.LFU.value:
elif self.eviction_policy == EvictionPolicy.LFU:
self.cache[command]["access_count"] = (
self.cache.get(command, {}).get("access_count", 0) + 1
)
self.cache.move_to_end(command)
elif self.eviction_policy == EvictionPolicy.RANDOM.value:
elif self.eviction_policy == EvictionPolicy.RANDOM:
pass # Random eviction doesn't require updates

def _evict(self):
"""Evict a redis command from the cache based on the eviction policy."""
if self._is_expired(self.commands_ttl_list[0]):
self.delete_command(self.commands_ttl_list[0])
elif self.eviction_policy == EvictionPolicy.LRU.value:
elif self.eviction_policy == EvictionPolicy.LRU:
self.cache.popitem(last=False)
elif self.eviction_policy == EvictionPolicy.LFU.value:
elif self.eviction_policy == EvictionPolicy.LFU:
min_access_command = min(
self.cache, key=lambda k: self.cache[k].get("access_count", 0)
)
self.cache.pop(min_access_command)
elif self.eviction_policy == EvictionPolicy.RANDOM.value:
elif self.eviction_policy == EvictionPolicy.RANDOM:
random_command = random.choice(list(self.cache.keys()))
self.cache.pop(random_command)

def _update_key_commands_map(self, keys: List[KeyT], command: str):
def _update_key_commands_map(
self, keys: List[KeyT], command: Union[str, Sequence[str]]
):
"""
Update the key_commands_map with command that uses the keys.
Args:
keys (List[KeyT]): The list of keys used in the command.
command (str): The redis command.
command (Union[str, Sequence[str]]): The redis command.
"""
for key in keys:
self.key_commands_map[key].add(command)

def _del_key_commands_map(self, keys: List[KeyT], command: str):
def _del_key_commands_map(
self, keys: List[KeyT], command: Union[str, Sequence[str]]
):
"""
Remove a redis command from the key_commands_map.
Args:
keys (List[KeyT]): The list of keys used in the redis command.
command (str): The redis command.
command (Union[str, Sequence[str]]): The redis command.
"""
for key in keys:
self.key_commands_map[key].remove(command)
Expand Down
48 changes: 20 additions & 28 deletions redis/asyncio/client.py
Expand Up @@ -27,9 +27,9 @@
)

from redis._cache import (
DEFAULT_BLACKLIST,
DEFAULT_ALLOW_LIST,
DEFAULT_DENY_LIST,
DEFAULT_EVICTION_POLICY,
DEFAULT_WHITELIST,
AbstractCache,
)
from redis._parsers.helpers import (
Expand Down Expand Up @@ -243,8 +243,8 @@ def __init__(
cache_max_size: int = 100,
cache_ttl: int = 0,
cache_policy: str = DEFAULT_EVICTION_POLICY,
cache_blacklist: List[str] = DEFAULT_BLACKLIST,
cache_whitelist: List[str] = DEFAULT_WHITELIST,
cache_deny_list: List[str] = DEFAULT_DENY_LIST,
cache_allow_list: List[str] = DEFAULT_ALLOW_LIST,
):
"""
Initialize a new Redis client.
Expand Down Expand Up @@ -299,8 +299,8 @@ def __init__(
"cache_max_size": cache_max_size,
"cache_ttl": cache_ttl,
"cache_policy": cache_policy,
"cache_blacklist": cache_blacklist,
"cache_whitelist": cache_whitelist,
"cache_deny_list": cache_deny_list,
"cache_allow_list": cache_allow_list,
}
# based on input, setup appropriate connection args
if unix_socket_path is not None:
Expand Down Expand Up @@ -640,7 +640,8 @@ async def execute_command(self, *args, **options):
),
lambda error: self._disconnect_raise(conn, error),
)
conn._add_to_local_cache(args, response, keys)
if keys:
conn._add_to_local_cache(args, response, keys)
return response
finally:
if self.single_connection_client:
Expand Down Expand Up @@ -675,31 +676,22 @@ async def parse_response(
return response

def flush_cache(self):
try:
if self.connection:
self.connection.client_cache.flush()
else:
self.connection_pool.flush_cache()
except AttributeError:
pass
if self.connection:
self.connection.flush_cache()
else:
self.connection_pool.flush_cache()

def delete_command_from_cache(self, command):
try:
if self.connection:
self.connection.client_cache.delete_command(command)
else:
self.connection_pool.delete_command_from_cache(command)
except AttributeError:
pass
if self.connection:
self.connection.delete_command_from_cache(command)
else:
self.connection_pool.delete_command_from_cache(command)

def invalidate_key_from_cache(self, key):
try:
if self.connection:
self.connection.client_cache.invalidate_key(key)
else:
self.connection_pool.invalidate_key_from_cache(key)
except AttributeError:
pass
if self.connection:
self.connection.invalidate_key_from_cache(key)
else:
self.connection_pool.invalidate_key_from_cache(key)


StrictRedis = Redis
Expand Down

0 comments on commit 3bd311c

Please sign in to comment.