diff --git a/redis/__init__.py b/redis/__init__.py
index 495d2d99b..4e54ff56e 100644
--- a/redis/__init__.py
+++ b/redis/__init__.py
@@ -3,7 +3,7 @@
 from redis import asyncio  # noqa
 from redis.backoff import default_backoff
 from redis.client import Redis, StrictRedis
-from redis.cluster import RedisCluster
+from redis.cluster import ReadFromReplicasMode, RedisCluster
 from redis.connection import (
     BlockingConnectionPool,
     Connection,
diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py
index 11c423b84..82ef6addb 100644
--- a/redis/asyncio/cluster.py
+++ b/redis/asyncio/cluster.py
@@ -44,6 +44,7 @@
     SLOT_ID,
     AbstractRedisCluster,
     LoadBalancer,
+    ReadFromReplicasMode,
     block_pipeline_command,
     get_node_name,
     parse_cluster_slots,
@@ -136,9 +137,12 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand
         | See:
           https://redis.io/docs/manual/scaling/#redis-cluster-configuration-parameters
     :param read_from_replicas:
-        | Enable read from replicas in READONLY mode. You can read possibly stale data.
-          When set to true, read commands will be assigned between the primary and
-          its replications in a Round-Robin manner.
+        | Enable read from replicas in READONLY mode. You can read possibly
+          stale data.
+        | When set to true, read commands will be assigned between the
+          primary and its replications in a Round-Robin manner. When set to
+          ReadFromReplicasMode.ReadFromReplicaOnly, it will only read from
+          the replicas.
     :param reinitialize_steps:
         | Specifies the number of MOVED errors that need to occur before reinitializing
           the whole cluster topology. If a MOVED error occurs and the cluster does not
@@ -238,7 +242,7 @@ def __init__(
         # Cluster related kwargs
         startup_nodes: Optional[List["ClusterNode"]] = None,
         require_full_coverage: bool = True,
-        read_from_replicas: bool = False,
+        read_from_replicas: "bool | ReadFromReplicasMode" = False,
         reinitialize_steps: int = 5,
         cluster_error_retry_attempts: int = 3,
         connection_error_retry_attempts: int = 3,
@@ -350,7 +354,9 @@ def __init__(
                 }
             )
 
-        if read_from_replicas:
+        self.read_from_replicas_mode = ReadFromReplicasMode.from_parameters(read_from_replicas)
+
+        if self.read_from_replicas_mode != ReadFromReplicasMode.ReadFromPrimary:
             # Call our on_connect function to configure READONLY mode
             kwargs["redis_connect_func"] = self.on_connect
 
@@ -392,7 +398,6 @@ def __init__(
             address_remap=address_remap,
         )
         self.encoder = Encoder(encoding, encoding_errors, decode_responses)
-        self.read_from_replicas = read_from_replicas
         self.reinitialize_steps = reinitialize_steps
         self.cluster_error_retry_attempts = cluster_error_retry_attempts
         self.connection_error_retry_attempts = connection_error_retry_attempts
@@ -610,7 +615,7 @@ async def _determine_nodes(
         return [
             self.nodes_manager.get_node_from_slot(
                 await self._determine_slot(command, *args),
-                self.read_from_replicas and command in READ_COMMANDS,
+                self.read_from_replicas_mode.get_replica_mode_for_command(command),
             )
         ]
 
@@ -791,7 +796,7 @@ async def _execute_command(
                     # refresh the target node
                     slot = await self._determine_slot(*args)
                     target_node = self.nodes_manager.get_node_from_slot(
-                        slot, self.read_from_replicas and args[0] in READ_COMMANDS
+                        slot, self.read_from_replicas_mode.get_replica_mode_for_command(args[0])
                     )
                     moved = False
 
@@ -1215,25 +1220,20 @@ def _update_moved_slots(self) -> None:
         self._moved_exception = None
 
     def get_node_from_slot(
-        self, slot: int, read_from_replicas: bool = False
+        self, slot: int, read_from_replicas_mode=ReadFromReplicasMode.ReadFromPrimary
     ) -> "ClusterNode":
+        """
+        Gets a node that serves this hash slot
+        """
         if self._moved_exception:
             self._update_moved_slots()
 
-        try:
-            if read_from_replicas:
-                # get the server index in a Round-Robin manner
-                primary_name = self.slots_cache[slot][0].name
-                node_idx = self.read_load_balancer.get_server_index(
-                    primary_name, len(self.slots_cache[slot])
-                )
-                return self.slots_cache[slot][node_idx]
-            return self.slots_cache[slot][0]
-        except (IndexError, TypeError):
-            raise SlotNotCoveredError(
-                f'Slot "{slot}" not covered by the cluster. '
-                f'"require_full_coverage={self.require_full_coverage}"'
-            )
+        # Slot-coverage check and node selection are done by the load balancer.
+        return self.read_load_balancer.get_node_from_slot(
+            slot,
+            self.slots_cache.get(slot),
+            read_from_replicas_mode,
+        )
 
     def get_nodes_by_server_type(self, server_type: str) -> List["ClusterNode"]:
         return [
diff --git a/redis/cluster.py b/redis/cluster.py
index a9213f423..37404a86d 100644
--- a/redis/cluster.py
+++ b/redis/cluster.py
@@ -1,3 +1,4 @@
+from enum import Enum
 import random
 import socket
 import sys
@@ -190,6 +191,27 @@ def cleanup_kwargs(**kwargs):
 
     return connection_kwargs
 
+
+class ReadFromReplicasMode(Enum):
+    ReadFromPrimary = 0
+    ReadFromPrimaryAndReplica = 1
+    ReadFromReplicaOnly = 2
+
+    @staticmethod
+    def from_parameters(input: "bool | ReadFromReplicasMode"):
+        if input is True:
+            return ReadFromReplicasMode.ReadFromPrimaryAndReplica
+        elif input is False:
+            return ReadFromReplicasMode.ReadFromPrimary
+        if not isinstance(input, ReadFromReplicasMode):
+            raise RedisClusterException("Argument 'read_from_replicas' must be a boolean or a value of ReadFromReplicasMode")
+        return input
+
+    def get_replica_mode_for_command(self, command: str):
+        if self == ReadFromReplicasMode.ReadFromPrimary or command not in READ_COMMANDS:
+            return ReadFromReplicasMode.ReadFromPrimary
+        return self
+
 
 class ClusterParser(DefaultParser):
     EXCEPTION_CLASSES = dict_merge(
@@ -503,7 +525,7 @@ def __init__(
         retry: Optional["Retry"] = None,
         require_full_coverage: bool = False,
         reinitialize_steps: int = 5,
-        read_from_replicas: bool = False,
+        read_from_replicas: "bool | ReadFromReplicasMode" = False,
         dynamic_startup_nodes: bool = True,
         url: Optional[str] = None,
         address_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None,
@@ -532,7 +554,9 @@ def __init__(
              Enable read from replicas in READONLY mode. You can read possibly
              stale data.
              When set to true, read commands will be assigned between the
-             primary and its replications in a Round-Robin manner.
+             primary and its replications in a Round-Robin manner. When set to
+             ReadFromReplicasMode.ReadFromReplicaOnly, it will only read from
+             the replicas.
          :param dynamic_startup_nodes:
              Set the RedisCluster's startup nodes to all of the discovered nodes.
              If true (default value), the cluster's discovered nodes will be used to
@@ -633,7 +657,7 @@ def __init__(
         self.cluster_error_retry_attempts = cluster_error_retry_attempts
         self.command_flags = self.__class__.COMMAND_FLAGS.copy()
         self.node_flags = self.__class__.NODE_FLAGS.copy()
-        self.read_from_replicas = read_from_replicas
+        self.read_from_replicas_mode = ReadFromReplicasMode.from_parameters(read_from_replicas)
         self.reinitialize_counter = 0
         self.reinitialize_steps = reinitialize_steps
         self.nodes_manager = NodesManager(
@@ -678,7 +702,7 @@ def on_connect(self, connection):
         connection.set_parser(ClusterParser)
         connection.on_connect()
 
-        if self.read_from_replicas:
+        if self.read_from_replicas_mode != ReadFromReplicasMode.ReadFromPrimary:
             # Sending READONLY command to server to configure connection as
             # readonly. Since each cluster node may change its server type due
             # to a failover, we should establish a READONLY connection
@@ -706,6 +730,13 @@ def get_primaries(self):
 
     def get_replicas(self):
         return self.nodes_manager.get_nodes_by_server_type(REPLICA)
+
+    def get_read_from_replica_mode_for_command(self, command: str):
+        """
+        Return the read mode to use for ``command``: the primary for
+        non-read commands, otherwise the configured mode.
+        """
+        return self.read_from_replicas_mode.get_replica_mode_for_command(command)
 
     def get_random_node(self):
         return random.choice(list(self.nodes_manager.nodes_cache.values()))
@@ -804,7 +835,7 @@ def pipeline(self, transaction=None, shard_hint=None):
             result_callbacks=self.result_callbacks,
             cluster_response_callbacks=self.cluster_response_callbacks,
             cluster_error_retry_attempts=self.cluster_error_retry_attempts,
-            read_from_replicas=self.read_from_replicas,
+            read_from_replicas_mode=self.read_from_replicas_mode,
             reinitialize_steps=self.reinitialize_steps,
             lock=self._lock,
         )
@@ -922,7 +953,7 @@ def _determine_nodes(self, *args, **kwargs) -> List["ClusterNode"]:
             # get the node that holds the key's slot
             slot = self.determine_slot(*args)
             node = self.nodes_manager.get_node_from_slot(
-                slot, self.read_from_replicas and command in READ_COMMANDS
+                slot, self.read_from_replicas_mode.get_replica_mode_for_command(command)
             )
             return [node]
 
@@ -1144,7 +1175,7 @@ def _execute_command(self, target_node, *args, **kwargs):
                     # refresh the target node
                     slot = self.determine_slot(*args)
                     target_node = self.nodes_manager.get_node_from_slot(
-                        slot, self.read_from_replicas and command in READ_COMMANDS
+                        slot, self.read_from_replicas_mode.get_replica_mode_for_command(command)
                     )
                     moved = False
 
@@ -1293,7 +1324,6 @@ def __del__(self):
         if self.redis_connection is not None:
             self.redis_connection.close()
 
-
 class LoadBalancer:
     """
     Round-Robin Load Balancing
@@ -1302,11 +1332,30 @@ class LoadBalancer:
     def __init__(self, start_index: int = 0) -> None:
         self.primary_to_idx = {}
         self.start_index = start_index
 
-    def get_server_index(self, primary: str, list_size: int) -> int:
-        server_index = self.primary_to_idx.setdefault(primary, self.start_index)
-        # Update the index
-        self.primary_to_idx[primary] = (server_index + 1) % list_size
+    def get_node_from_slot(self, slot_index: int, slot_nodes: Optional[List["ClusterNode"]], read_from_replicas_mode: ReadFromReplicasMode):
+        if not slot_nodes:
+            raise SlotNotCoveredError(
+                f'Slot "{slot_index}" not covered by the cluster.'
+            )
+        if read_from_replicas_mode == ReadFromReplicasMode.ReadFromPrimary:
+            node_idx = 0
+        else:
+            skip_primary = read_from_replicas_mode == ReadFromReplicasMode.ReadFromReplicaOnly
+            # get the server index in a Round-Robin manner
+            primary_name = slot_nodes[0].name
+            node_idx = self.get_server_index(
+                primary_name, len(slot_nodes), skip_primary
+            )
+        return slot_nodes[node_idx]
+
+    def get_server_index(self, primary: str, list_size: int, skip_primary: bool = False) -> int:
+        # Resume after the last served index (new primaries start at start_index).
+        last_index = self.primary_to_idx.get(primary, self.start_index - 1)
+        server_index = (last_index + 1) % list_size
+        if skip_primary and server_index == 0 and list_size > 1:
+            server_index = 1  # index 0 is the primary; jump to the first replica
+        self.primary_to_idx[primary] = server_index
         return server_index
 
     def reset(self) -> None:
@@ -1401,7 +1450,9 @@ def _update_moved_slots(self):
         # Reset moved_exception
         self._moved_exception = None
 
-    def get_node_from_slot(self, slot, read_from_replicas=False, server_type=None):
+    def get_node_from_slot(
+        self, slot: int, read_from_replicas_mode=ReadFromReplicasMode.ReadFromPrimary
+    ) -> "ClusterNode":
         """
         Gets a node that servers this hash slot
         """
@@ -1409,33 +1460,13 @@ def get_node_from_slot(self, slot, read_from_replicas=False, server_type=None):
             with self._lock:
                 if self._moved_exception:
                     self._update_moved_slots()
 
-        if self.slots_cache.get(slot) is None or len(self.slots_cache[slot]) == 0:
-            raise SlotNotCoveredError(
-                f'Slot "{slot}" not covered by the cluster. '
-                f'"require_full_coverage={self._require_full_coverage}"'
-            )
-
-        if read_from_replicas is True:
-            # get the server index in a Round-Robin manner
-            primary_name = self.slots_cache[slot][0].name
-            node_idx = self.read_load_balancer.get_server_index(
-                primary_name, len(self.slots_cache[slot])
-            )
-        elif (
-            server_type is None
-            or server_type == PRIMARY
-            or len(self.slots_cache[slot]) == 1
-        ):
-            # return a primary
-            node_idx = 0
-        else:
-            # return a replica
-            # randomly choose one of the replicas
-            node_idx = random.randint(1, len(self.slots_cache[slot]) - 1)
-
-        return self.slots_cache[slot][node_idx]
+        return self.read_load_balancer.get_node_from_slot(
+            slot,
+            self.slots_cache.get(slot),
+            read_from_replicas_mode,
+        )
 
     def get_nodes_by_server_type(self, server_type):
         """
         Get all nodes with the specified server type
@@ -1775,7 +1806,7 @@ def execute_command(self, *args):
                     channel = args[1]
                     slot = self.cluster.keyslot(channel)
                     node = self.cluster.nodes_manager.get_node_from_slot(
-                        slot, self.cluster.read_from_replicas
+                        slot, self.cluster.read_from_replicas_mode
                     )
                 else:
                     # Get a random node
@@ -1915,7 +1946,7 @@ def __init__(
         result_callbacks: Optional[Dict[str, Callable]] = None,
         cluster_response_callbacks: Optional[Dict[str, Callable]] = None,
         startup_nodes: Optional[List["ClusterNode"]] = None,
-        read_from_replicas: bool = False,
+        read_from_replicas_mode: ReadFromReplicasMode = ReadFromReplicasMode.ReadFromPrimary,
         cluster_error_retry_attempts: int = 3,
         reinitialize_steps: int = 5,
         lock=None,
@@ -1930,7 +1961,7 @@ def __init__(
             result_callbacks or self.__class__.RESULT_CALLBACKS.copy()
         )
         self.startup_nodes = startup_nodes if startup_nodes else []
-        self.read_from_replicas = read_from_replicas
+        self.read_from_replicas_mode = read_from_replicas_mode
         self.command_flags = self.__class__.COMMAND_FLAGS.copy()
         self.cluster_response_callbacks = cluster_response_callbacks
         self.cluster_error_retry_attempts = cluster_error_retry_attempts