Skip to content

Commit

Permalink
add cluster "host_port_remap" feature
Browse files Browse the repository at this point in the history
  • Loading branch information
kristjanvalur committed Apr 15, 2023
1 parent 45cabb5 commit caafb98
Showing 1 changed file with 24 additions and 1 deletion.
25 changes: 24 additions & 1 deletion redis/asyncio/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@
import warnings
from typing import (
Any,
Callable,
Deque,
Dict,
Generator,
List,
Mapping,
Optional,
Tuple,
Type,
TypeVar,
Union,
Expand Down Expand Up @@ -250,6 +252,7 @@ def __init__(
ssl_certfile: Optional[str] = None,
ssl_check_hostname: bool = False,
ssl_keyfile: Optional[str] = None,
host_port_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None,
) -> None:
if db:
raise RedisClusterException(
Expand Down Expand Up @@ -337,7 +340,12 @@ def __init__(
if host and port:
startup_nodes.append(ClusterNode(host, port, **self.connection_kwargs))

self.nodes_manager = NodesManager(startup_nodes, require_full_coverage, kwargs)
self.nodes_manager = NodesManager(
startup_nodes,
require_full_coverage,
kwargs,
host_port_remap=host_port_remap,
)
self.encoder = Encoder(encoding, encoding_errors, decode_responses)
self.read_from_replicas = read_from_replicas
self.reinitialize_steps = reinitialize_steps
Expand Down Expand Up @@ -1044,17 +1052,20 @@ class NodesManager:
"require_full_coverage",
"slots_cache",
"startup_nodes",
"host_port_remap",
)

def __init__(
self,
startup_nodes: List["ClusterNode"],
require_full_coverage: bool,
connection_kwargs: Dict[str, Any],
host_port_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None,
) -> None:
self.startup_nodes = {node.name: node for node in startup_nodes}
self.require_full_coverage = require_full_coverage
self.connection_kwargs = connection_kwargs
self.host_port_remap = host_port_remap

self.default_node: "ClusterNode" = None
self.nodes_cache: Dict[str, "ClusterNode"] = {}
Expand Down Expand Up @@ -1213,6 +1224,7 @@ async def initialize(self) -> None:
if host == "":
host = startup_node.host
port = int(primary_node[1])
host, port = self.remap_host_port(host, port)

target_node = tmp_nodes_cache.get(get_node_name(host, port))
if not target_node:
Expand All @@ -1231,6 +1243,7 @@ async def initialize(self) -> None:
for replica_node in replica_nodes:
host = replica_node[0]
port = replica_node[1]
host, port = self.remap_host_port(host, port)

target_replica_node = tmp_nodes_cache.get(
get_node_name(host, port)
Expand Down Expand Up @@ -1304,6 +1317,16 @@ async def close(self, attr: str = "nodes_cache") -> None:
)
)

def remap_host_port(self, host: str, port: int) -> Tuple[str, int]:
"""
Remap the host and port returned from the cluster to a different
internal value. Useful if the client is not connecting directly
to the cluster.
"""
if self.host_port_remap:
return self.host_port_remap(host, port)
return host, port


class ClusterPipeline(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommands):
"""
Expand Down

0 comments on commit caafb98

Please sign in to comment.