Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add replica-only read mode to cluster and asyncio cluster #3182

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion redis/__init__.py
Expand Up @@ -3,7 +3,7 @@
from redis import asyncio # noqa
from redis.backoff import default_backoff
from redis.client import Redis, StrictRedis
from redis.cluster import RedisCluster
from redis.cluster import RedisCluster, ReadFromReplicasMode
from redis.connection import (
BlockingConnectionPool,
Connection,
Expand Down
46 changes: 23 additions & 23 deletions redis/asyncio/cluster.py
Expand Up @@ -44,6 +44,7 @@
SLOT_ID,
AbstractRedisCluster,
LoadBalancer,
ReadFromReplicasMode,
block_pipeline_command,
get_node_name,
parse_cluster_slots,
Expand Down Expand Up @@ -136,9 +137,12 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand
| See:
https://redis.io/docs/manual/scaling/#redis-cluster-configuration-parameters
:param read_from_replicas:
| Enable read from replicas in READONLY mode. You can read possibly stale data.
When set to true, read commands will be assigned between the primary and
its replications in a Round-Robin manner.
| Enable read from replicas in READONLY mode. You can read possibly
stale data.
| When set to true, read commands will be assigned between the
primary and its replications in a Round-Robin manner. When set to
ReadFromReplicasMode.ReadFromReplicaOnly, it will only read from
the replicas
:param reinitialize_steps:
| Specifies the number of MOVED errors that need to occur before reinitializing
the whole cluster topology. If a MOVED error occurs and the cluster does not
Expand Down Expand Up @@ -238,7 +242,7 @@ def __init__(
# Cluster related kwargs
startup_nodes: Optional[List["ClusterNode"]] = None,
require_full_coverage: bool = True,
read_from_replicas: bool = False,
read_from_replicas: bool|ReadFromReplicasMode = False,
reinitialize_steps: int = 5,
cluster_error_retry_attempts: int = 3,
connection_error_retry_attempts: int = 3,
Expand Down Expand Up @@ -350,7 +354,9 @@ def __init__(
}
)

if read_from_replicas:
self.read_from_replicas_mode = ReadFromReplicasMode.from_parameters(read_from_replicas)

if self.read_from_replicas_mode != ReadFromReplicasMode.ReadFromPrimary:
# Call our on_connect function to configure READONLY mode
kwargs["redis_connect_func"] = self.on_connect

Expand Down Expand Up @@ -392,7 +398,6 @@ def __init__(
address_remap=address_remap,
)
self.encoder = Encoder(encoding, encoding_errors, decode_responses)
self.read_from_replicas = read_from_replicas
self.reinitialize_steps = reinitialize_steps
self.cluster_error_retry_attempts = cluster_error_retry_attempts
self.connection_error_retry_attempts = connection_error_retry_attempts
Expand Down Expand Up @@ -610,7 +615,7 @@ async def _determine_nodes(
return [
self.nodes_manager.get_node_from_slot(
await self._determine_slot(command, *args),
self.read_from_replicas and command in READ_COMMANDS,
self.read_from_replicas_mode.get_replica_mode_for_command(command)
)
]

Expand Down Expand Up @@ -791,7 +796,7 @@ async def _execute_command(
# refresh the target node
slot = await self._determine_slot(*args)
target_node = self.nodes_manager.get_node_from_slot(
slot, self.read_from_replicas and args[0] in READ_COMMANDS
slot, self.read_from_replicas_mode.get_replica_mode_for_command(args[0])
)
moved = False

Expand Down Expand Up @@ -1215,25 +1220,20 @@ def _update_moved_slots(self) -> None:
self._moved_exception = None

def get_node_from_slot(
self, slot: int, read_from_replicas: bool = False
self, slot: int, read_from_replicas_mode: ReadFromReplicasMode
) -> "ClusterNode":
"""
Gets a node that servers this hash slot
"""
if self._moved_exception:
self._update_moved_slots()

try:
if read_from_replicas:
# get the server index in a Round-Robin manner
primary_name = self.slots_cache[slot][0].name
node_idx = self.read_load_balancer.get_server_index(
primary_name, len(self.slots_cache[slot])
)
return self.slots_cache[slot][node_idx]
return self.slots_cache[slot][0]
except (IndexError, TypeError):
raise SlotNotCoveredError(
f'Slot "{slot}" not covered by the cluster. '
f'"require_full_coverage={self.require_full_coverage}"'

return self.read_load_balancer.get_node_from_slot(
slot,
self.slots_cache.get(slot, None),
read_from_replicas_mode,
)


def get_nodes_by_server_type(self, server_type: str) -> List["ClusterNode"]:
return [
Expand Down
117 changes: 74 additions & 43 deletions redis/cluster.py
@@ -1,3 +1,4 @@
from enum import Enum
import random
import socket
import sys
Expand Down Expand Up @@ -190,6 +191,27 @@ def cleanup_kwargs(**kwargs):

return connection_kwargs

class ReadFromReplicasMode(Enum):
ReadFromPrimary = 0
ReadFromPrimaryAndReplica = 1
ReadFromReplicaOnly = 2

@staticmethod
def from_parameters(input: bool|"ReadFromReplicasMode"):
if input == True:
return ReadFromReplicasMode.ReadFromPrimaryAndReplica
elif input == False:
return ReadFromReplicasMode.ReadFromPrimary
if not input in ReadFromReplicasMode:
raise RedisClusterException("Argument 'read_from_replicas' must be a boolean or a value of ReadFromReplicasMode")
return input

def get_replica_mode_for_command(self, command: str):
if self == ReadFromReplicasMode.ReadFromPrimary:
return ReadFromReplicasMode.ReadFromPrimary
if not command in READ_COMMANDS:
return ReadFromReplicasMode.ReadFromPrimary
return self

class ClusterParser(DefaultParser):
EXCEPTION_CLASSES = dict_merge(
Expand Down Expand Up @@ -503,7 +525,7 @@ def __init__(
retry: Optional["Retry"] = None,
require_full_coverage: bool = False,
reinitialize_steps: int = 5,
read_from_replicas: bool = False,
read_from_replicas: bool|ReadFromReplicasMode = False,
dynamic_startup_nodes: bool = True,
url: Optional[str] = None,
address_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None,
Expand Down Expand Up @@ -532,7 +554,9 @@ def __init__(
Enable read from replicas in READONLY mode. You can read possibly
stale data.
When set to true, read commands will be assigned between the
primary and its replications in a Round-Robin manner.
primary and its replications in a Round-Robin manner. When set to
ReadFromReplicasMode.ReadFromReplicaOnly, it will only read from
the replicas
:param dynamic_startup_nodes:
Set the RedisCluster's startup nodes to all of the discovered nodes.
If true (default value), the cluster's discovered nodes will be used to
Expand Down Expand Up @@ -633,7 +657,7 @@ def __init__(
self.cluster_error_retry_attempts = cluster_error_retry_attempts
self.command_flags = self.__class__.COMMAND_FLAGS.copy()
self.node_flags = self.__class__.NODE_FLAGS.copy()
self.read_from_replicas = read_from_replicas
self.read_from_replicas_mode = ReadFromReplicasMode.from_parameters(read_from_replicas)
self.reinitialize_counter = 0
self.reinitialize_steps = reinitialize_steps
self.nodes_manager = NodesManager(
Expand Down Expand Up @@ -678,7 +702,7 @@ def on_connect(self, connection):
connection.set_parser(ClusterParser)
connection.on_connect()

if self.read_from_replicas:
if self.read_from_replicas != ReadFromReplicasMode.ReadFromPrimary:
# Sending READONLY command to server to configure connection as
# readonly. Since each cluster node may change its server type due
# to a failover, we should establish a READONLY connection
Expand Down Expand Up @@ -706,6 +730,13 @@ def get_primaries(self):

def get_replicas(self):
return self.nodes_manager.get_nodes_by_server_type(REPLICA)

def get_read_from_replica_mode_for_command(self, command: str):
if (
(self.read_from_replicas_mode == ReadFromReplicasMode.ReadFromPrimary) or
(not command in READ_COMMANDS)):
return ReadFromReplicasMode.ReadFromPrimary
return self.read_from_replicas_mode

def get_random_node(self):
return random.choice(list(self.nodes_manager.nodes_cache.values()))
Expand Down Expand Up @@ -804,7 +835,7 @@ def pipeline(self, transaction=None, shard_hint=None):
result_callbacks=self.result_callbacks,
cluster_response_callbacks=self.cluster_response_callbacks,
cluster_error_retry_attempts=self.cluster_error_retry_attempts,
read_from_replicas=self.read_from_replicas,
read_from_replicas_mode=self.read_from_replicas_mode,
reinitialize_steps=self.reinitialize_steps,
lock=self._lock,
)
Expand Down Expand Up @@ -922,7 +953,7 @@ def _determine_nodes(self, *args, **kwargs) -> List["ClusterNode"]:
# get the node that holds the key's slot
slot = self.determine_slot(*args)
node = self.nodes_manager.get_node_from_slot(
slot, self.read_from_replicas and command in READ_COMMANDS
slot, self.read_from_replicas_mode.get_replica_mode_for_command(command)
)
return [node]

Expand Down Expand Up @@ -1144,7 +1175,7 @@ def _execute_command(self, target_node, *args, **kwargs):
# refresh the target node
slot = self.determine_slot(*args)
target_node = self.nodes_manager.get_node_from_slot(
slot, self.read_from_replicas and command in READ_COMMANDS
slot, self.read_from_replicas_mode.get_replica_mode_for_command(command)
)
moved = False

Expand Down Expand Up @@ -1293,7 +1324,6 @@ def __del__(self):
if self.redis_connection is not None:
self.redis_connection.close()


class LoadBalancer:
"""
Round-Robin Load Balancing
Expand All @@ -1302,11 +1332,30 @@ class LoadBalancer:
def __init__(self, start_index: int = 0) -> None:
self.primary_to_idx = {}
self.start_index = start_index

def get_server_index(self, primary: str, list_size: int) -> int:
server_index = self.primary_to_idx.setdefault(primary, self.start_index)
# Update the index
self.primary_to_idx[primary] = (server_index + 1) % list_size

def get_node_from_slot(self, slot_index: int, slot_nodes: list[ClusterNode] | None, read_from_replicas_mode: ReadFromReplicasMode):
if slot_nodes is None or len(slot_nodes) == 0:
raise SlotNotCoveredError(
f'Slot "{slot_index}" not covered by the cluster. '
)
if read_from_replicas_mode == ReadFromReplicasMode.ReadFromPrimary:
node_idx = 0
else:
skip_primary = read_from_replicas_mode == ReadFromReplicasMode.ReadFromReplicaOnly
# get the server index in a Round-Robin manner
primary_name = slot_nodes[0].name
node_idx = self.read_load_balancer.get_server_index(
primary_name, len(slot_nodes), skip_primary
)
return slot_nodes[node_idx]

def get_server_index(self, primary: str, list_size: int, skip_primary:bool) -> int:
# default to -1 if not found, so after incrementing it will be 0
server_index = (self.primary_to_idx.get(primary, -1) + 1) % list_size
# If we skip primary, skip the zero-index node.
if skip_primary and server_index == 0 and list_size > 1:
server_index = server_index + 1
self.primary_to_idx[primary] = server_index
return server_index

def reset(self) -> None:
Expand Down Expand Up @@ -1401,41 +1450,23 @@ def _update_moved_slots(self):
# Reset moved_exception
self._moved_exception = None

def get_node_from_slot(self, slot, read_from_replicas=False, server_type=None):
def get_node_from_slot(
self, slot: int, read_from_replicas_mode: ReadFromReplicasMode
) -> "ClusterNode":
"""
Gets a node that servers this hash slot
"""
if self._moved_exception:
with self._lock:
if self._moved_exception:
self._update_moved_slots()

if self.slots_cache.get(slot) is None or len(self.slots_cache[slot]) == 0:
raise SlotNotCoveredError(
f'Slot "{slot}" not covered by the cluster. '
f'"require_full_coverage={self._require_full_coverage}"'
)

if read_from_replicas is True:
# get the server index in a Round-Robin manner
primary_name = self.slots_cache[slot][0].name
node_idx = self.read_load_balancer.get_server_index(
primary_name, len(self.slots_cache[slot])

return self.read_load_balancer.get_node_from_slot(
slot,
self.slots_cache.get(slot, None),
read_from_replicas_mode,
)
elif (
server_type is None
or server_type == PRIMARY
or len(self.slots_cache[slot]) == 1
):
# return a primary
node_idx = 0
else:
# return a replica
# randomly choose one of the replicas
node_idx = random.randint(1, len(self.slots_cache[slot]) - 1)

return self.slots_cache[slot][node_idx]


def get_nodes_by_server_type(self, server_type):
"""
Get all nodes with the specified server type
Expand Down Expand Up @@ -1775,7 +1806,7 @@ def execute_command(self, *args):
channel = args[1]
slot = self.cluster.keyslot(channel)
node = self.cluster.nodes_manager.get_node_from_slot(
slot, self.cluster.read_from_replicas
slot, self.cluster.read_from_replicas_mode
)
else:
# Get a random node
Expand Down Expand Up @@ -1915,7 +1946,7 @@ def __init__(
result_callbacks: Optional[Dict[str, Callable]] = None,
cluster_response_callbacks: Optional[Dict[str, Callable]] = None,
startup_nodes: Optional[List["ClusterNode"]] = None,
read_from_replicas: bool = False,
read_from_replicas_mode: ReadFromReplicasMode = ReadFromReplicasMode.ReadFromPrimary,
cluster_error_retry_attempts: int = 3,
reinitialize_steps: int = 5,
lock=None,
Expand All @@ -1930,7 +1961,7 @@ def __init__(
result_callbacks or self.__class__.RESULT_CALLBACKS.copy()
)
self.startup_nodes = startup_nodes if startup_nodes else []
self.read_from_replicas = read_from_replicas
self.read_from_replicas_mode = read_from_replicas_mode
self.command_flags = self.__class__.COMMAND_FLAGS.copy()
self.cluster_response_callbacks = cluster_response_callbacks
self.cluster_error_retry_attempts = cluster_error_retry_attempts
Expand Down