diff --git a/tests/test_asyncio/test_cwe_404.py b/tests/test_asyncio/test_cwe_404.py index 07d0724f79..a427975a4e 100644 --- a/tests/test_asyncio/test_cwe_404.py +++ b/tests/test_asyncio/test_cwe_404.py @@ -182,37 +182,67 @@ async def op(pipe): @pytest.mark.onlycluster async def test_cluster(request, redis_addr): - # TODO: This test actually doesn't work. Once the RedisCluster initializes, - # it will re-connect to the nodes as advertised by the cluster, bypassing - # the single DelayProxy we set up. - # to work around this, we really would nedd a port-remapper for the RedisCluster - - redis_addr = redis_addr[0], 6372 # use the cluster port delay = 0.1 - dp = DelayProxy(addr=("127.0.0.1", 5381), redis_addr=redis_addr) - await dp.start() + cluster_port = 6372 + remap_base = 7372 + n_nodes = 6 + + def remap(host, port): + return host, remap_base + port - cluster_port + + proxies = [] + for i in range(n_nodes): + port = cluster_port + i + remapped = remap_base + i + forward_addr = redis_addr[0], port + proxy = DelayProxy(addr=("127.0.0.1", remapped), redis_addr=forward_addr) + proxies.append(proxy) + + # start proxies + await asyncio.gather(*[p.start() for p in proxies]) + + def all_clear(): + for p in proxies: + p.send_event.clear() + + async def wait_for_send(): + asyncio.wait( + [p.send_event.wait() for p in proxies], return_when=asyncio.FIRST_COMPLETED + ) + + @contextlib.contextmanager + def set_delay(delay: float): + with contextlib.ExitStack() as stack: + for p in proxies: + stack.enter_context(p.set_delay(delay)) + yield - with contextlib.closing(RedisCluster.from_url("redis://127.0.0.1:5381")) as r: + with contextlib.closing( + RedisCluster.from_url(f"redis://127.0.0.1:{remap_base}", host_port_remap=remap) + ) as r: await r.initialize() await r.set("foo", "foo") await r.set("bar", "bar") async def op(r): - with dp.set_delay(delay): + with set_delay(delay): return await r.get("foo") - dp.send_event.clear() + all_clear() t = asyncio.create_task(op(r)) - # await dp.send_event.wait() # won"t work, because DelayProxy is by-passed + # Wait for whichever DelayProxy gets the request first + await wait_for_send() await asyncio.sleep(0.01) t.cancel() - try: + with pytest.raises(asyncio.CancelledError): await t - except asyncio.CancelledError: - pass - assert await r.get("bar") == b"bar" - assert await r.ping() - assert await r.get("foo") == b"foo" + # try a number of requests to excercise all the connections + async def doit(): + assert await r.get("bar") == b"bar" + assert await r.ping() + assert await r.get("foo") == b"foo" - await dp.stop() + await asyncio.gather(*[doit() for _ in range(10)]) + + await asyncio.gather(*(p.stop() for p in proxies))