diff --git a/tests/test_asyncio/test_cwe_404.py b/tests/test_asyncio/test_cwe_404.py index ab59f8c0cc..ddf7e97980 100644 --- a/tests/test_asyncio/test_cwe_404.py +++ b/tests/test_asyncio/test_cwe_404.py @@ -168,34 +168,66 @@ async def test_standalone_pipeline(delay, redis_addr): @pytest.mark.onlycluster async def test_cluster(request, redis_addr): - # TODO: This test actually doesn't work. Once the RedisCluster initializes, - # it will re-connect to the nodes as advertised by the cluster, bypassing - # the single DelayProxy we set up. - # to work around this, we really would nedd a port-remapper for the RedisCluster + delay = 0.1 + cluster_port = 6372 + remap_base = 7372 + n_nodes = 6 + + remap = [] + proxies = [] + for i in range(n_nodes): + port = cluster_port + i + remapped = remap_base + i + remap.append({"from_port": port, "to_port": remapped}) + forward_addr = redis_addr[0], port + proxy = DelayProxy( + addr=("127.0.0.1", remapped), redis_addr=forward_addr, delay=delay + ) + proxies.append(proxy) - redis_addr = redis_addr[0], 6372 # use the cluster port - dp = DelayProxy(addr=("127.0.0.1", 5381), redis_addr=redis_addr, delay=0.1) - await dp.start() + # start proxies + await asyncio.gather(*[p.start() for p in proxies]) + + def all_clear(): + for p in proxies: + p.send_event.clear() - r = RedisCluster.from_url("redis://127.0.0.1:5381") - await r.initialize() - with dp.override(): + async def wait_for_send(): + asyncio.wait( + [p.send_event.wait() for p in proxies], return_when=asyncio.FIRST_COMPLETED + ) + + @contextlib.contextmanager + def override(): + with contextlib.ExitStack() as stack: + for p in proxies: + stack.enter_context(p.override()) + yield + + with override(): + r = RedisCluster.from_url( + f"redis://127.0.0.1:{remap_base}", host_port_remap=remap + ) + await r.initialize() await r.set("foo", "foo") await r.set("bar", "bar") - dp.send_event.clear() + all_clear() t = asyncio.create_task(r.get("foo")) - # await dp.send_event.wait() # won"t work, because DelayProxy is by-passed - await asyncio.sleep(0.05) + # cannot wait on the send event, we don't know which node will be used + await wait_for_send() + await asyncio.sleep(delay) t.cancel() - try: + with pytest.raises(asyncio.CancelledError): await t - except asyncio.CancelledError: - pass - with dp.override(): - assert await r.get("bar") == b"bar" - assert await r.ping() - assert await r.get("foo") == b"foo" + with override(): + # try a number of requests to excercise all the connections + async def doit(): + assert await r.get("bar") == b"bar" + assert await r.ping() + assert await r.get("foo") == b"foo" - await dp.stop() + await asyncio.gather(*[doit() for _ in range(10)]) + + await asyncio.gather(*(p.stop() for p in proxies))