From f9f2a397d636d4b2f10c0159f7273bf3c57eb81f Mon Sep 17 00:00:00 2001 From: Johan Henkens Date: Tue, 28 May 2024 11:49:49 -0700 Subject: [PATCH] Fix get_node_from_slot during resharding --- redis/asyncio/cluster.py | 23 ++++++-------- redis/cluster.py | 66 +++++++++++++++++++--------------------- 2 files changed, 40 insertions(+), 49 deletions(-) diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index 719c2d228..0d4accb38 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -1244,23 +1244,18 @@ def _update_moved_slots(self) -> None: def get_node_from_slot( self, slot: int, read_from_replicas: bool = False ) -> "ClusterNode": + """ + Gets a node that serves this hash slot + """ if self._moved_exception: self._update_moved_slots() - - try: - if read_from_replicas: - # get the server index in a Round-Robin manner - primary_name = self.slots_cache[slot][0].name - node_idx = self.read_load_balancer.get_server_index( - primary_name, len(self.slots_cache[slot]) - ) - return self.slots_cache[slot][node_idx] - return self.slots_cache[slot][0] - except (IndexError, TypeError): - raise SlotNotCoveredError( - f'Slot "{slot}" not covered by the cluster. ' - f'"require_full_coverage={self.require_full_coverage}"' + + return self.read_load_balancer.get_node_from_slot( + slot, + self.slots_cache.get(slot, None), + read_from_replicas, ) + def get_nodes_by_server_type(self, server_type: str) -> List["ClusterNode"]: return [ diff --git a/redis/cluster.py b/redis/cluster.py index 144844ec8..1b81e0875 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -1324,18 +1324,32 @@ class LoadBalancer: Round-Robin Load Balancing """ - def __init__(self, start_index: int = 0) -> None: - self.primary_to_idx = {} - self.start_index = start_index + def __init__(self) -> None: + self.primary_name_to_last_used_index:dict[str,int] = {} + + def get_node_from_slot(self, slot_index: int, slot_nodes: list[ClusterNode] | None, read_from_replicas: bool) -> ClusterNode: + if slot_nodes is None or len(slot_nodes) == 0: + raise SlotNotCoveredError( + f'Slot "{slot_index}" not covered by the cluster. ' + ) + if not read_from_replicas: + node_idx = 0 + else: + primary_name = slot_nodes[0].name + node_idx = self.read_load_balancer.get_server_index( + primary_name, len(slot_nodes) + ) + return slot_nodes[node_idx] - def get_server_index(self, primary: str, list_size: int) -> int: - server_index = self.primary_to_idx.setdefault(primary, self.start_index) + def get_server_index(self, primary: str, list_size: int) -> int: + # default to -1 if not found, so after incrementing it will be 0 + server_index = (self.primary_name_to_last_used_index.get(primary, -1) + 1) % list_size # Update the index - self.primary_to_idx[primary] = (server_index + 1) % list_size + self.primary_name_to_last_used_index[primary] = server_index return server_index def reset(self) -> None: - self.primary_to_idx.clear() + self.primary_name_to_last_used_index.clear() class NodesManager: @@ -1426,41 +1440,23 @@ def _update_moved_slots(self): # Reset moved_exception self._moved_exception = None - def get_node_from_slot(self, slot, read_from_replicas=False, server_type=None): + def get_node_from_slot( + self, slot: int, read_from_replicas: bool + ) -> "ClusterNode": """ - Gets a node that servers this hash slot + Gets a node that serves this hash slot """ if self._moved_exception: with self._lock: if self._moved_exception: self._update_moved_slots() - - if self.slots_cache.get(slot) is None or len(self.slots_cache[slot]) == 0: - raise SlotNotCoveredError( - f'Slot "{slot}" not covered by the cluster. ' - f'"require_full_coverage={self._require_full_coverage}"' + + return self.read_load_balancer.get_node_from_slot( + slot, + self.slots_cache.get(slot, None), + read_from_replicas, ) - - if read_from_replicas is True: - # get the server index in a Round-Robin manner - primary_name = self.slots_cache[slot][0].name - node_idx = self.read_load_balancer.get_server_index( - primary_name, len(self.slots_cache[slot]) - ) - elif ( - server_type is None - or server_type == PRIMARY - or len(self.slots_cache[slot]) == 1 - ): - # return a primary - node_idx = 0 - else: - # return a replica - # randomly choose one of the replicas - node_idx = random.randint(1, len(self.slots_cache[slot]) - 1) - - return self.slots_cache[slot][node_idx] - + def get_nodes_by_server_type(self, server_type): """ Get all nodes with the specified server type