Skip to content

Commit

Permalink
Support client side caching with RedisCluster (#3102)
Browse files Browse the repository at this point in the history
* sync

* fix mock_node_resp

* fix mock_node_resp_func

* fix test_handling_cluster_failover_to_a_replica

* fix test_handling_cluster_failover_to_a_replica

* async cluster and cleanup tests

* delete comment
  • Loading branch information
dvora-h committed Jan 9, 2024
1 parent b5d4d29 commit c7a13ae
Show file tree
Hide file tree
Showing 9 changed files with 398 additions and 268 deletions.
46 changes: 37 additions & 9 deletions redis/asyncio/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@
Union,
)

from redis._cache import (
DEFAULT_BLACKLIST,
DEFAULT_EVICTION_POLICY,
DEFAULT_WHITELIST,
_LocalCache,
)
from redis._parsers import AsyncCommandsParser, Encoder
from redis._parsers.helpers import (
_RedisCallbacks,
Expand Down Expand Up @@ -267,6 +273,13 @@ def __init__(
ssl_keyfile: Optional[str] = None,
protocol: Optional[int] = 2,
address_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None,
cache_enable: bool = False,
client_cache: Optional[_LocalCache] = None,
cache_max_size: int = 100,
cache_ttl: int = 0,
cache_eviction_policy: str = DEFAULT_EVICTION_POLICY,
cache_blacklist: List[str] = DEFAULT_BLACKLIST,
cache_whitelist: List[str] = DEFAULT_WHITELIST,
) -> None:
if db:
raise RedisClusterException(
Expand Down Expand Up @@ -310,6 +323,14 @@ def __init__(
"socket_timeout": socket_timeout,
"retry": retry,
"protocol": protocol,
# Client cache related kwargs
"cache_enable": cache_enable,
"client_cache": client_cache,
"cache_max_size": cache_max_size,
"cache_ttl": cache_ttl,
"cache_eviction_policy": cache_eviction_policy,
"cache_blacklist": cache_blacklist,
"cache_whitelist": cache_whitelist,
}

if ssl:
Expand Down Expand Up @@ -682,7 +703,6 @@ async def execute_command(self, *args: EncodableT, **kwargs: Any) -> Any:
:raises RedisClusterException: if target_nodes is not provided & the command
can't be mapped to a slot
"""
kwargs.pop("keys", None) # the keys are used only for client side caching
command = args[0]
target_nodes = []
target_nodes_specified = False
Expand Down Expand Up @@ -1039,16 +1059,24 @@ async def parse_response(
async def execute_command(self, *args: Any, **kwargs: Any) -> Any:
# Acquire connection
connection = self.acquire_connection()
keys = kwargs.pop("keys", None)

# Execute command
await connection.send_packed_command(connection.pack_command(*args), False)

# Read response
try:
return await self.parse_response(connection, args[0], **kwargs)
finally:
# Release connection
response_from_cache = await connection._get_from_local_cache(args)
if response_from_cache is not None:
self._free.append(connection)
return response_from_cache
else:
# Execute command
await connection.send_packed_command(connection.pack_command(*args), False)

# Read response
try:
response = await self.parse_response(connection, args[0], **kwargs)
connection._add_to_local_cache(args, response, keys)
return response
finally:
# Release connection
self._free.append(connection)

async def execute_pipeline(self, commands: List["PipelineCommand"]) -> bool:
# Acquire connection
Expand Down
4 changes: 4 additions & 0 deletions redis/asyncio/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,10 @@ def __init__(
_cache = None
self.client_cache = client_cache if client_cache is not None else _cache
if self.client_cache is not None:
if self.protocol not in [3, "3"]:
raise RedisError(
"client caching is only supported with protocol version 3 or higher"
)
self.cache_blacklist = cache_blacklist
self.cache_whitelist = cache_whitelist

Expand Down
29 changes: 20 additions & 9 deletions redis/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,13 @@ def parse_cluster_myshardid(resp, **options):
"ssl_password",
"unix_socket_path",
"username",
"cache_enable",
"client_cache",
"cache_max_size",
"cache_ttl",
"cache_eviction_policy",
"cache_blacklist",
"cache_whitelist",
)
KWARGS_DISABLED_KEYS = ("host", "port")

Expand Down Expand Up @@ -1060,7 +1067,6 @@ def execute_command(self, *args, **kwargs):
list<ClusterNode>
dict<Any, ClusterNode>
"""
kwargs.pop("keys", None) # the keys are used only for client side caching
target_nodes_specified = False
is_default_node = False
target_nodes = None
Expand Down Expand Up @@ -1119,6 +1125,7 @@ def _execute_command(self, target_node, *args, **kwargs):
"""
Send a command to a node in the cluster
"""
keys = kwargs.pop("keys", None)
command = args[0]
redis_node = None
connection = None
Expand Down Expand Up @@ -1147,14 +1154,18 @@ def _execute_command(self, target_node, *args, **kwargs):
connection.send_command("ASKING")
redis_node.parse_response(connection, "ASKING", **kwargs)
asking = False

connection.send_command(*args)
response = redis_node.parse_response(connection, command, **kwargs)
if command in self.cluster_response_callbacks:
response = self.cluster_response_callbacks[command](
response, **kwargs
)
return response
response_from_cache = connection._get_from_local_cache(args)
if response_from_cache is not None:
return response_from_cache
else:
connection.send_command(*args)
response = redis_node.parse_response(connection, command, **kwargs)
if command in self.cluster_response_callbacks:
response = self.cluster_response_callbacks[command](
response, **kwargs
)
connection._add_to_local_cache(args, response, keys)
return response
except AuthenticationError:
raise
except (ConnectionError, TimeoutError) as e:
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ def _get_client(
redis_url = request.config.getoption("--redis-url")
else:
redis_url = from_url
if "protocol" not in redis_url:
if "protocol" not in redis_url and kwargs.get("protocol") is None:
kwargs["protocol"] = request.config.getoption("--protocol")

cluster_mode = REDIS_INFO["cluster_enabled"]
Expand Down
3 changes: 1 addition & 2 deletions tests/test_asyncio/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,9 @@ async def client_factory(
url: str = request.config.getoption("--redis-url"),
cls=redis.Redis,
flushdb=True,
protocol=request.config.getoption("--protocol"),
**kwargs,
):
if "protocol" not in url:
if "protocol" not in url and kwargs.get("protocol") is None:
kwargs["protocol"] = request.config.getoption("--protocol")

cluster_mode = REDIS_INFO["cluster_enabled"]
Expand Down

0 comments on commit c7a13ae

Please sign in to comment.