diff --git a/tests/test_asyncio/test_cwe_404.py b/tests/test_asyncio/test_cwe_404.py index 66060a0668..9026f57994 100644 --- a/tests/test_asyncio/test_cwe_404.py +++ b/tests/test_asyncio/test_cwe_404.py @@ -17,31 +17,12 @@ def redis_addr(request): return host, int(port) -async def pipe( - reader: asyncio.StreamReader, - writer: asyncio.StreamWriter, - proxy: "DelayProxy", - name="", - event: asyncio.Event = None, -): - while True: - data = await reader.read(1000) - if not data: - break - if event: - event.set() - await asyncio.sleep(proxy.delay) - writer.write(data) - await writer.drain() - - class DelayProxy: def __init__(self, addr, redis_addr, delay: float): self.addr = addr self.redis_addr = redis_addr self.delay = delay self.send_event = asyncio.Event() - self.redis_streams = None async def start(self): # test that we can connect to redis @@ -52,31 +33,48 @@ async def start(self): self.ROUTINE = asyncio.create_task(self.server.serve_forever()) @contextlib.contextmanager - def override(self, delay: float = 0.0): + def set_delay(self, delay: float = 0.0): """ Allow to override the delay for parts of tests which aren't time dependent, to speed up execution. """ - old = self.delay + old_delay = self.delay self.delay = delay try: yield finally: - self.delay = old + self.delay = old_delay async def handle(self, reader, writer): # establish connection to redis redis_reader, redis_writer = await asyncio.open_connection(*self.redis_addr) try: pipe1 = asyncio.create_task( - pipe(reader, redis_writer, self, "to redis:", self.send_event) + self.pipe(reader, redis_writer, "to redis:", self.send_event) ) - pipe2 = asyncio.create_task(pipe(redis_reader, writer, self, "from redis:")) + pipe2 = asyncio.create_task(self.pipe(redis_reader, writer, "from redis:")) await asyncio.gather(pipe1, pipe2) finally: redis_writer.close() redis_reader.close() + async def pipe( + self, + reader: asyncio.StreamReader, + writer: asyncio.StreamWriter, + name="", + event: asyncio.Event = None, + ): + while True: + data = await reader.read(1000) + if not data: + break + if event: + event.set() + await asyncio.sleep(self.delay) + writer.write(data) + await writer.drain() + async def stop(self): # clean up enough so that we can reuse the looper self.ROUTINE.cancel() @@ -101,7 +99,7 @@ async def test_standalone(delay, redis_addr): # note that we connect to proxy, rather than to Redis directly async with Redis(host="127.0.0.1", port=5380, single_connection_client=b) as r: - with dp.override(): + with dp.set_delay(0): await r.set("foo", "foo") await r.set("bar", "bar") @@ -117,7 +115,7 @@ async def test_standalone(delay, redis_addr): # make sure that our previous request, cancelled while waiting for # a repsponse, didn't leave the connection open andin a bad state - with dp.override(): + with dp.set_delay(0): assert await r.get("bar") == b"bar" assert await r.ping() assert await r.get("foo") == b"foo" @@ -132,7 +130,7 @@ async def test_standalone_pipeline(delay, redis_addr): await dp.start() for b in [True, False]: async with Redis(host="127.0.0.1", port=5380, single_connection_client=b) as r: - with dp.override(): + with dp.set_delay(0): await r.set("foo", "foo") await r.set("bar", "bar") @@ -154,7 +152,7 @@ async def test_standalone_pipeline(delay, redis_addr): # we have now cancelled the pieline in the middle of a request, make sure # that the connection is still usable - with dp.override(): + with dp.set_delay(0): pipe.get("bar") pipe.ping() pipe.get("foo") @@ -205,10 +203,10 @@ async def any_wait(): ) @contextlib.contextmanager - def all_override(delay: int = 0): + def set_delay(delay: int = 0): with contextlib.ExitStack() as stack: for p in proxies: - stack.enter_context(p.override(delay=delay)) + stack.enter_context(p.set_delay(delay)) yield # start proxies @@ -222,9 +220,9 @@ def all_override(delay: int = 0): await r.set("bar", "bar") all_clear() - with all_override(delay=delay): + with set_delay(delay=delay): t = asyncio.create_task(r.get("foo")) - # cannot wait on the send event, we don't know which node will be used + # One of the proxies will handle our request, wait for it to send await any_wait() await asyncio.sleep(delay) t.cancel()