Skip to content

Commit

Permalink
Add replica-only read mode to cluster and asyncio cluster
Browse files Browse the repository at this point in the history
  • Loading branch information
jhenkens committed Mar 13, 2024
1 parent ddff7b5 commit 25273b6
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 67 deletions.
2 changes: 1 addition & 1 deletion redis/__init__.py
Expand Up @@ -3,7 +3,7 @@
from redis import asyncio # noqa
from redis.backoff import default_backoff
from redis.client import Redis, StrictRedis
from redis.cluster import RedisCluster
from redis.cluster import RedisCluster, ReadFromReplicasMode
from redis.connection import (
BlockingConnectionPool,
Connection,
Expand Down
46 changes: 23 additions & 23 deletions redis/asyncio/cluster.py
Expand Up @@ -44,6 +44,7 @@
SLOT_ID,
AbstractRedisCluster,
LoadBalancer,
ReadFromReplicasMode,
block_pipeline_command,
get_node_name,
parse_cluster_slots,
Expand Down Expand Up @@ -136,9 +137,12 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand
| See:
https://redis.io/docs/manual/scaling/#redis-cluster-configuration-parameters
:param read_from_replicas:
| Enable read from replicas in READONLY mode. You can read possibly stale data.
When set to true, read commands will be assigned between the primary and
its replications in a Round-Robin manner.
| Enable read from replicas in READONLY mode. You can read possibly
stale data.
| When set to true, read commands will be distributed between the
primary and its replicas in a round-robin manner. When set to
ReadFromReplicasMode.ReadFromReplicaOnly, read commands are sent
only to the replicas.
:param reinitialize_steps:
| Specifies the number of MOVED errors that need to occur before reinitializing
the whole cluster topology. If a MOVED error occurs and the cluster does not
Expand Down Expand Up @@ -238,7 +242,7 @@ def __init__(
# Cluster related kwargs
startup_nodes: Optional[List["ClusterNode"]] = None,
require_full_coverage: bool = True,
read_from_replicas: bool = False,
read_from_replicas: bool|ReadFromReplicasMode = False,
reinitialize_steps: int = 5,
cluster_error_retry_attempts: int = 3,
connection_error_retry_attempts: int = 3,
Expand Down Expand Up @@ -350,7 +354,9 @@ def __init__(
}
)

if read_from_replicas:
self.read_from_replicas_mode = ReadFromReplicasMode.from_parameters(read_from_replicas)

if self.read_from_replicas_mode != ReadFromReplicasMode.ReadFromPrimary:
# Call our on_connect function to configure READONLY mode
kwargs["redis_connect_func"] = self.on_connect

Expand Down Expand Up @@ -392,7 +398,6 @@ def __init__(
address_remap=address_remap,
)
self.encoder = Encoder(encoding, encoding_errors, decode_responses)
self.read_from_replicas = read_from_replicas
self.reinitialize_steps = reinitialize_steps
self.cluster_error_retry_attempts = cluster_error_retry_attempts
self.connection_error_retry_attempts = connection_error_retry_attempts
Expand Down Expand Up @@ -610,7 +615,7 @@ async def _determine_nodes(
return [
self.nodes_manager.get_node_from_slot(
await self._determine_slot(command, *args),
self.read_from_replicas and command in READ_COMMANDS,
self.read_from_replicas_mode.get_replica_mode_for_command(command)
)
]

Expand Down Expand Up @@ -791,7 +796,7 @@ async def _execute_command(
# refresh the target node
slot = await self._determine_slot(*args)
target_node = self.nodes_manager.get_node_from_slot(
slot, self.read_from_replicas and args[0] in READ_COMMANDS
slot, self.read_from_replicas_mode.get_replica_mode_for_command(args[0])
)
moved = False

Expand Down Expand Up @@ -1215,25 +1220,20 @@ def _update_moved_slots(self) -> None:
self._moved_exception = None

def get_node_from_slot(
self, slot: int, read_from_replicas: bool = False
self, slot: int, read_from_replicas_mode: ReadFromReplicasMode
) -> "ClusterNode":
"""
Gets a node that serves this hash slot
"""
if self._moved_exception:
self._update_moved_slots()

try:
if read_from_replicas:
# get the server index in a Round-Robin manner
primary_name = self.slots_cache[slot][0].name
node_idx = self.read_load_balancer.get_server_index(
primary_name, len(self.slots_cache[slot])
)
return self.slots_cache[slot][node_idx]
return self.slots_cache[slot][0]
except (IndexError, TypeError):
raise SlotNotCoveredError(
f'Slot "{slot}" not covered by the cluster. '
f'"require_full_coverage={self.require_full_coverage}"'

return self.read_load_balancer.get_node_from_slot(
slot,
self.slots_cache.get(slot, None),
read_from_replicas_mode,
)


def get_nodes_by_server_type(self, server_type: str) -> List["ClusterNode"]:
return [
Expand Down
117 changes: 74 additions & 43 deletions redis/cluster.py
@@ -1,3 +1,4 @@
from enum import Enum
import random
import socket
import sys
Expand Down Expand Up @@ -190,6 +191,27 @@ def cleanup_kwargs(**kwargs):

return connection_kwargs

class ReadFromReplicasMode(Enum):
    """
    Routing policy for read commands issued through a cluster client.

    ReadFromPrimary: every command is routed to the slot's primary.
    ReadFromPrimaryAndReplica: read commands may be routed to the primary
        or to one of its replicas.
    ReadFromReplicaOnly: read commands are routed to replicas.
    """

    ReadFromPrimary = 0
    ReadFromPrimaryAndReplica = 1
    ReadFromReplicaOnly = 2

    @staticmethod
    def from_parameters(input: "bool | ReadFromReplicasMode"):
        """
        Normalise the user-facing ``read_from_replicas`` argument.

        :param input: ``True`` / ``False`` (legacy boolean flag) or a
            ``ReadFromReplicasMode`` member. The parameter name shadows the
            builtin but is kept so keyword callers keep working.
        :return: the matching ``ReadFromReplicasMode`` member.
        :raises RedisClusterException: if ``input`` is neither a boolean
            nor a ``ReadFromReplicasMode`` member.
        """
        # Check for enum members with isinstance rather than
        # ``input in ReadFromReplicasMode``: the containment test raises
        # TypeError for non-members on Python 3.8-3.11 and accepts raw
        # values (e.g. the int 2) on 3.12+, which would leak a non-enum
        # value to callers.
        if isinstance(input, ReadFromReplicasMode):
            return input
        # Equality (not identity) is deliberate: it keeps 1 / 0 behaving
        # like True / False, as they did before.
        if input == True:  # noqa: E712
            return ReadFromReplicasMode.ReadFromPrimaryAndReplica
        if input == False:  # noqa: E712
            return ReadFromReplicasMode.ReadFromPrimary
        raise RedisClusterException("Argument 'read_from_replicas' must be a boolean or a value of ReadFromReplicasMode")

    def get_replica_mode_for_command(self, command: str):
        """
        Return the routing mode that applies to a single command.

        Commands that are not in ``READ_COMMANDS`` are always routed with
        ``ReadFromPrimary``; read commands use this mode unchanged.

        :param command: the command name to be routed.
        :return: the ``ReadFromReplicasMode`` to use for ``command``.
        """
        if self is ReadFromReplicasMode.ReadFromPrimary:
            return ReadFromReplicasMode.ReadFromPrimary
        if command not in READ_COMMANDS:
            return ReadFromReplicasMode.ReadFromPrimary
        return self

class ClusterParser(DefaultParser):
EXCEPTION_CLASSES = dict_merge(
Expand Down Expand Up @@ -503,7 +525,7 @@ def __init__(
retry: Optional["Retry"] = None,
require_full_coverage: bool = False,
reinitialize_steps: int = 5,
read_from_replicas: bool = False,
read_from_replicas: bool|ReadFromReplicasMode = False,
dynamic_startup_nodes: bool = True,
url: Optional[str] = None,
address_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None,
Expand Down Expand Up @@ -532,7 +554,9 @@ def __init__(
Enable read from replicas in READONLY mode. You can read possibly
stale data.
When set to true, read commands will be assigned between the
primary and its replications in a Round-Robin manner.
primary and its replicas in a round-robin manner. When set to
ReadFromReplicasMode.ReadFromReplicaOnly, read commands are sent
only to the replicas.
:param dynamic_startup_nodes:
Set the RedisCluster's startup nodes to all of the discovered nodes.
If true (default value), the cluster's discovered nodes will be used to
Expand Down Expand Up @@ -633,7 +657,7 @@ def __init__(
self.cluster_error_retry_attempts = cluster_error_retry_attempts
self.command_flags = self.__class__.COMMAND_FLAGS.copy()
self.node_flags = self.__class__.NODE_FLAGS.copy()
self.read_from_replicas = read_from_replicas
self.read_from_replicas_mode = ReadFromReplicasMode.from_parameters(read_from_replicas)
self.reinitialize_counter = 0
self.reinitialize_steps = reinitialize_steps
self.nodes_manager = NodesManager(
Expand Down Expand Up @@ -678,7 +702,7 @@ def on_connect(self, connection):
connection.set_parser(ClusterParser)
connection.on_connect()

if self.read_from_replicas:
if self.read_from_replicas != ReadFromReplicasMode.ReadFromPrimary:
# Sending READONLY command to server to configure connection as
# readonly. Since each cluster node may change its server type due
# to a failover, we should establish a READONLY connection
Expand Down Expand Up @@ -706,6 +730,13 @@ def get_primaries(self):

def get_replicas(self):
return self.nodes_manager.get_nodes_by_server_type(REPLICA)

def get_read_from_replica_mode_for_command(self, command: str):
    """
    Return the read-routing mode to use for ``command``.

    ``ReadFromPrimary`` is returned when the client is configured with
    that mode or when ``command`` is not listed in ``READ_COMMANDS``;
    otherwise the client's configured mode is returned.
    """
    # Guard 1: the client never reads from replicas.
    if self.read_from_replicas_mode == ReadFromReplicasMode.ReadFromPrimary:
        return ReadFromReplicasMode.ReadFromPrimary
    # Guard 2: only read commands are eligible for replica routing.
    if command not in READ_COMMANDS:
        return ReadFromReplicasMode.ReadFromPrimary
    return self.read_from_replicas_mode

def get_random_node(self):
return random.choice(list(self.nodes_manager.nodes_cache.values()))
Expand Down Expand Up @@ -804,7 +835,7 @@ def pipeline(self, transaction=None, shard_hint=None):
result_callbacks=self.result_callbacks,
cluster_response_callbacks=self.cluster_response_callbacks,
cluster_error_retry_attempts=self.cluster_error_retry_attempts,
read_from_replicas=self.read_from_replicas,
read_from_replicas_mode=self.read_from_replicas_mode,
reinitialize_steps=self.reinitialize_steps,
lock=self._lock,
)
Expand Down Expand Up @@ -922,7 +953,7 @@ def _determine_nodes(self, *args, **kwargs) -> List["ClusterNode"]:
# get the node that holds the key's slot
slot = self.determine_slot(*args)
node = self.nodes_manager.get_node_from_slot(
slot, self.read_from_replicas and command in READ_COMMANDS
slot, self.read_from_replicas_mode.get_replica_mode_for_command(command)
)
return [node]

Expand Down Expand Up @@ -1144,7 +1175,7 @@ def _execute_command(self, target_node, *args, **kwargs):
# refresh the target node
slot = self.determine_slot(*args)
target_node = self.nodes_manager.get_node_from_slot(
slot, self.read_from_replicas and command in READ_COMMANDS
slot, self.read_from_replicas_mode.get_replica_mode_for_command(command)
)
moved = False

Expand Down Expand Up @@ -1293,7 +1324,6 @@ def __del__(self):
if self.redis_connection is not None:
self.redis_connection.close()


class LoadBalancer:
"""
Round-Robin Load Balancing
Expand All @@ -1302,11 +1332,30 @@ class LoadBalancer:
def __init__(self, start_index: int = 0) -> None:
self.primary_to_idx = {}
self.start_index = start_index

def get_server_index(self, primary: str, list_size: int) -> int:
server_index = self.primary_to_idx.setdefault(primary, self.start_index)
# Update the index
self.primary_to_idx[primary] = (server_index + 1) % list_size

def get_node_from_slot(
    self,
    slot_index: int,
    slot_nodes: "list[ClusterNode] | None",
    read_from_replicas_mode: "ReadFromReplicasMode",
):
    """
    Pick the node that should serve a command for the given slot.

    :param slot_index: the hash slot number; used only in the error
        message.
    :param slot_nodes: the nodes cached for the slot, with the primary
        at index 0, or ``None`` when the slot is not cached.
    :param read_from_replicas_mode: routing mode for this command.
    :return: the selected entry of ``slot_nodes``.
    :raises SlotNotCoveredError: if ``slot_nodes`` is ``None`` or empty.
    """
    if not slot_nodes:
        raise SlotNotCoveredError(
            f'Slot "{slot_index}" not covered by the cluster. '
        )
    if read_from_replicas_mode == ReadFromReplicasMode.ReadFromPrimary:
        # The primary is always stored first in the slot's node list.
        return slot_nodes[0]

    skip_primary = (
        read_from_replicas_mode == ReadFromReplicasMode.ReadFromReplicaOnly
    )
    # Get the server index in a round-robin manner, keyed by the primary's
    # name. This object *is* the load balancer, so call the method on
    # ``self`` (there is no ``read_load_balancer`` attribute here).
    node_idx = self.get_server_index(
        slot_nodes[0].name, len(slot_nodes), skip_primary
    )
    return slot_nodes[node_idx]

def get_server_index(
    self, primary: str, list_size: int, skip_primary: bool = False
) -> int:
    """
    Return the next round-robin index for the node list of ``primary``.

    :param primary: name of the primary node; used as the rotation key.
    :param list_size: number of nodes in the slot's node list (primary
        at index 0). Must be greater than zero.
    :param skip_primary: when true, index 0 is skipped as long as the
        list has more than one node; with a single node, 0 is returned.
        Defaults to ``False`` so two-argument callers keep working.
    :return: the index to use for this call; it is also stored as the
        last index served for ``primary``.
    """
    # Seed with ``start_index - 1`` so that the first call for a primary
    # yields ``start_index`` after the increment below.
    previous_index = self.primary_to_idx.get(primary, self.start_index - 1)
    server_index = (previous_index + 1) % list_size
    # If we skip the primary, step over the zero-index node, unless it is
    # the only node available for the slot.
    if skip_primary and server_index == 0 and list_size > 1:
        server_index = 1
    self.primary_to_idx[primary] = server_index
    return server_index

def reset(self) -> None:
Expand Down Expand Up @@ -1401,41 +1450,23 @@ def _update_moved_slots(self):
# Reset moved_exception
self._moved_exception = None

def get_node_from_slot(self, slot, read_from_replicas=False, server_type=None):
def get_node_from_slot(
self, slot: int, read_from_replicas_mode: ReadFromReplicasMode
) -> "ClusterNode":
"""
Gets a node that serves this hash slot
"""
if self._moved_exception:
with self._lock:
if self._moved_exception:
self._update_moved_slots()

if self.slots_cache.get(slot) is None or len(self.slots_cache[slot]) == 0:
raise SlotNotCoveredError(
f'Slot "{slot}" not covered by the cluster. '
f'"require_full_coverage={self._require_full_coverage}"'
)

if read_from_replicas is True:
# get the server index in a Round-Robin manner
primary_name = self.slots_cache[slot][0].name
node_idx = self.read_load_balancer.get_server_index(
primary_name, len(self.slots_cache[slot])

return self.read_load_balancer.get_node_from_slot(
slot,
self.slots_cache.get(slot, None),
read_from_replicas_mode,
)
elif (
server_type is None
or server_type == PRIMARY
or len(self.slots_cache[slot]) == 1
):
# return a primary
node_idx = 0
else:
# return a replica
# randomly choose one of the replicas
node_idx = random.randint(1, len(self.slots_cache[slot]) - 1)

return self.slots_cache[slot][node_idx]


def get_nodes_by_server_type(self, server_type):
"""
Get all nodes with the specified server type
Expand Down Expand Up @@ -1775,7 +1806,7 @@ def execute_command(self, *args):
channel = args[1]
slot = self.cluster.keyslot(channel)
node = self.cluster.nodes_manager.get_node_from_slot(
slot, self.cluster.read_from_replicas
slot, self.cluster.read_from_replicas_mode
)
else:
# Get a random node
Expand Down Expand Up @@ -1915,7 +1946,7 @@ def __init__(
result_callbacks: Optional[Dict[str, Callable]] = None,
cluster_response_callbacks: Optional[Dict[str, Callable]] = None,
startup_nodes: Optional[List["ClusterNode"]] = None,
read_from_replicas: bool = False,
read_from_replicas_mode: ReadFromReplicasMode = ReadFromReplicasMode.ReadFromPrimary,
cluster_error_retry_attempts: int = 3,
reinitialize_steps: int = 5,
lock=None,
Expand All @@ -1930,7 +1961,7 @@ def __init__(
result_callbacks or self.__class__.RESULT_CALLBACKS.copy()
)
self.startup_nodes = startup_nodes if startup_nodes else []
self.read_from_replicas = read_from_replicas
self.read_from_replicas_mode = read_from_replicas_mode
self.command_flags = self.__class__.COMMAND_FLAGS.copy()
self.cluster_response_callbacks = cluster_response_callbacks
self.cluster_error_retry_attempts = cluster_error_retry_attempts
Expand Down

0 comments on commit 25273b6

Please sign in to comment.