Skip to content

Commit

Permalink
Change host_port_remap into a callable
Browse files Browse the repository at this point in the history
  • Loading branch information
kristjanvalur committed Apr 12, 2023
1 parent fe72d7b commit 6f6b6f6
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 27 deletions.
23 changes: 5 additions & 18 deletions redis/asyncio/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import warnings
from typing import (
Any,
Callable,
Deque,
Dict,
Generator,
Expand Down Expand Up @@ -251,7 +252,7 @@ def __init__(
ssl_certfile: Optional[str] = None,
ssl_check_hostname: bool = False,
ssl_keyfile: Optional[str] = None,
host_port_remap: List[Dict[str, Any]] = [],
host_port_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None,
) -> None:
if db:
raise RedisClusterException(
Expand Down Expand Up @@ -1059,7 +1060,7 @@ def __init__(
startup_nodes: List["ClusterNode"],
require_full_coverage: bool,
connection_kwargs: Dict[str, Any],
host_port_remap: List[Dict[str, Any]] = [],
host_port_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None,
) -> None:
self.startup_nodes = {node.name: node for node in startup_nodes}
self.require_full_coverage = require_full_coverage
Expand Down Expand Up @@ -1322,22 +1323,8 @@ def remap_host_port(self, host: str, port: int) -> Tuple[str, int]:
internal value. Useful if the client is not connecting directly
to the cluster.
"""
for map_entry in self.host_port_remap:
mapped = False
if "from_host" in map_entry:
if host != map_entry["from_host"]:
continue
else:
host = map_entry["to_host"]
mapped = True
if "from_port" in map_entry:
if port != map_entry["from_port"]:
continue
else:
port = map_entry["to_port"]
mapped = True
if mapped:
break
if self.host_port_remap:
return self.host_port_remap(host, port)
return host, port


Expand Down
21 changes: 12 additions & 9 deletions tests/test_asyncio/test_cwe_404.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,37 +181,39 @@ async def test_cluster(request, redis_addr):
remap_base = 7372
n_nodes = 6

remap = []
def remap(host, port):
return host, remap_base + port - cluster_port

proxies = []
for i in range(n_nodes):
port = cluster_port + i
remapped = remap_base + i
remap.append({"from_port": port, "to_port": remapped})
forward_addr = redis_addr[0], port
proxy = DelayProxy(
addr=("127.0.0.1", remapped), redis_addr=forward_addr, delay=0
)
proxies.append(proxy)

# start proxies
await asyncio.gather(*[p.start() for p in proxies])

# helpers to work with all or any proxy
def all_clear():
for p in proxies:
p.send_event.clear()

async def wait_for_send():
async def any_wait():
asyncio.wait(
[p.send_event.wait() for p in proxies], return_when=asyncio.FIRST_COMPLETED
)

@contextlib.contextmanager
def override(delay: int = 0):
def all_override(delay: int = 0):
with contextlib.ExitStack() as stack:
for p in proxies:
stack.enter_context(p.override(delay=delay))
yield

# start proxies
await asyncio.gather(*[p.start() for p in proxies])

with contextlib.closing(
RedisCluster.from_url(f"redis://127.0.0.1:{remap_base}", host_port_remap=remap)
) as r:
Expand All @@ -220,10 +222,10 @@ def override(delay: int = 0):
await r.set("bar", "bar")

all_clear()
with override(delay=delay):
with all_override(delay=delay):
t = asyncio.create_task(r.get("foo"))
# cannot wait on the send event, we don't know which node will be used
await wait_for_send()
await any_wait()
await asyncio.sleep(delay)
t.cancel()
with pytest.raises(asyncio.CancelledError):
Expand All @@ -237,4 +239,5 @@ async def doit():

await asyncio.gather(*[doit() for _ in range(10)])

# stop proxies
await asyncio.gather(*(p.stop() for p in proxies))

0 comments on commit 6f6b6f6

Please sign in to comment.