Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixing cancelled async futures #2666

Merged
merged 18 commits into from Mar 29, 2023
2 changes: 2 additions & 0 deletions .github/workflows/integration.yaml
Expand Up @@ -51,6 +51,7 @@ jobs:
timeout-minutes: 30
strategy:
max-parallel: 15
fail-fast: false
chayim marked this conversation as resolved.
Show resolved Hide resolved
matrix:
python-version: ['3.7', '3.8', '3.9', '3.10', '3.11', 'pypy-3.7', 'pypy-3.8', 'pypy-3.9']
test-type: ['standalone', 'cluster']
Expand Down Expand Up @@ -108,6 +109,7 @@ jobs:
name: Install package from commit hash
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python-version: ['3.7', '3.8', '3.9', '3.10', '3.11', 'pypy-3.7', 'pypy-3.8', 'pypy-3.9']
steps:
Expand Down
48 changes: 32 additions & 16 deletions redis/asyncio/client.py
Expand Up @@ -511,12 +511,17 @@ async def execute_command(self, *args, **options):
if self.single_connection_client:
await self._single_conn_lock.acquire()
try:
return await conn.retry.call_with_retry(
lambda: self._send_command_parse_response(
conn, command_name, *args, **options
),
lambda error: self._disconnect_raise(conn, error),
return await asyncio.shield(
chayim marked this conversation as resolved.
Show resolved Hide resolved
conn.retry.call_with_retry(
lambda: self._send_command_parse_response(
conn, command_name, *args, **options
),
lambda error: self._disconnect_raise(conn, error),
)
)
except asyncio.CancelledError:
await conn.disconnect(nowait=True)
raise
finally:
if self.single_connection_client:
self._single_conn_lock.release()
Expand Down Expand Up @@ -772,10 +777,16 @@ async def _execute(self, conn, command, *args, **kwargs):
called by the # connection to resubscribe us to any channels and
patterns we were previously listening to
"""
return await conn.retry.call_with_retry(
lambda: command(*args, **kwargs),
lambda error: self._disconnect_raise_connect(conn, error),
)
try:
return await asyncio.shield(
conn.retry.call_with_retry(
lambda: command(*args, **kwargs),
lambda error: self._disconnect_raise_connect(conn, error),
)
)
except asyncio.CancelledError:
await conn.disconnect()
raise

async def parse_response(self, block: bool = True, timeout: float = 0):
"""Parse the response from a publish/subscribe command"""
Expand Down Expand Up @@ -1191,13 +1202,18 @@ async def immediate_execute_command(self, *args, **options):
command_name, self.shard_hint
)
self.connection = conn

return await conn.retry.call_with_retry(
lambda: self._send_command_parse_response(
conn, command_name, *args, **options
),
lambda error: self._disconnect_reset_raise(conn, error),
)
try:
return await asyncio.shield(
conn.retry.call_with_retry(
lambda: self._send_command_parse_response(
conn, command_name, *args, **options
),
lambda error: self._disconnect_reset_raise(conn, error),
)
)
except asyncio.CancelledError:
chayim marked this conversation as resolved.
Show resolved Hide resolved
await conn.disconnect()
raise

def pipeline_execute_command(self, *args, **options):
"""
Expand Down
7 changes: 5 additions & 2 deletions redis/asyncio/cluster.py
Expand Up @@ -1029,9 +1029,12 @@ async def execute_pipeline(self, commands: List["PipelineCommand"]) -> bool:
ret = False
for cmd in commands:
try:
cmd.result = await self.parse_response(
connection, cmd.args[0], **cmd.kwargs
cmd.result = await asyncio.shield(
self.parse_response(connection, cmd.args[0], **cmd.kwargs)
)
except asyncio.CancelledError:
await connection.disconnect(nowait=True)
raise
except Exception as e:
cmd.result = e
ret = True
Expand Down
17 changes: 0 additions & 17 deletions tests/test_asyncio/test_cluster.py
Expand Up @@ -340,23 +340,6 @@ async def test_from_url(self, request: FixtureRequest) -> None:
rc = RedisCluster.from_url("rediss://localhost:16379")
assert rc.connection_kwargs["connection_class"] is SSLConnection

async def test_asynckills(self, r) -> None:

await r.set("foo", "foo")
await r.set("bar", "bar")

t = asyncio.create_task(r.get("foo"))
await asyncio.sleep(1)
t.cancel()
try:
await t
except asyncio.CancelledError:
pytest.fail("connection is left open with unread response")

assert await r.get("bar") == b"bar"
assert await r.ping()
assert await r.get("foo") == b"foo"

async def test_max_connections(
self, create_redis: Callable[..., RedisCluster]
) -> None:
Expand Down
21 changes: 0 additions & 21 deletions tests/test_asyncio/test_connection.py
Expand Up @@ -44,27 +44,6 @@ async def test_invalid_response(create_redis):
await r.connection.disconnect()


async def test_asynckills():

for b in [True, False]:
r = Redis(single_connection_client=b)

await r.set("foo", "foo")
await r.set("bar", "bar")

t = asyncio.create_task(r.get("foo"))
await asyncio.sleep(1)
t.cancel()
try:
await t
except asyncio.CancelledError:
pytest.fail("connection left open with unread response")

assert await r.get("bar") == b"bar"
assert await r.ping()
assert await r.get("foo") == b"foo"


@pytest.mark.onlynoncluster
async def test_single_connection():
"""Test that concurrent requests on a single client are synchronised."""
Expand Down
129 changes: 129 additions & 0 deletions tests/test_asyncio/test_cwe_404.py
@@ -0,0 +1,129 @@
import asyncio
import sys

import pytest

from redis.asyncio import Redis
from redis.asyncio.cluster import RedisCluster


async def pipe(
reader: asyncio.StreamReader, writer: asyncio.StreamWriter, delay: float, name=""
):
while True:
data = await reader.read(1000)
if not data:
break
await asyncio.sleep(delay)
writer.write(data)
await writer.drain()


class DelayProxy:
def __init__(self, addr, redis_addr, delay: float):
self.addr = addr
self.redis_addr = redis_addr
self.delay = delay

async def start(self):
self.server = await asyncio.start_server(self.handle, *self.addr)
self.ROUTINE = asyncio.create_task(self.server.serve_forever())

async def handle(self, reader, writer):
# establish connection to redis
redis_reader, redis_writer = await asyncio.open_connection(*self.redis_addr)
pipe1 = asyncio.create_task(pipe(reader, redis_writer, self.delay, "to redis:"))
pipe2 = asyncio.create_task(
pipe(redis_reader, writer, self.delay, "from redis:")
)
await asyncio.gather(pipe1, pipe2)

async def stop(self):
# clean up enough so that we can reuse the looper
self.ROUTINE.cancel()
loop = self.server.get_loop()
await loop.shutdown_asyncgens()


async def test_standalone():

# create a tcp socket proxy that relays data to Redis and back,
# inserting 0.1 seconds of delay
dp = DelayProxy(addr=("localhost", 5380), redis_addr=("localhost", 6379), delay=0.1)
await dp.start()

for b in [True, False]:
# note that we connect to proxy, rather than to Redis directly
async with Redis(host="localhost", port=5380, single_connection_client=b) as r:

await r.set("foo", "foo")
await r.set("bar", "bar")

t = asyncio.create_task(r.get("foo"))
await asyncio.sleep(0.050)
t.cancel()
try:
await t
sys.stderr.write("try again, we did not cancel the task in time\n")
except asyncio.CancelledError:
sys.stderr.write(
"canceled task, connection is left open with unread response\n"
)

assert await r.get("bar") == b"bar"
assert await r.ping()
assert await r.get("foo") == b"foo"

await dp.stop()


async def test_standalone_pipeline():
dp = DelayProxy(addr=("localhost", 5380), redis_addr=("localhost", 6379), delay=0.1)
await dp.start()
async with Redis(host="localhost", port=5380) as r:
await r.set("foo", "foo")
await r.set("bar", "bar")

pipe = r.pipeline()

pipe2 = r.pipeline()
pipe2.get("bar")
pipe2.ping()
pipe2.get("foo")

t = asyncio.create_task(pipe.get("foo").execute())
await asyncio.sleep(0.050)
t.cancel()

pipe.get("bar")
pipe.ping()
pipe.get("foo")
assert await pipe.execute() == [b"foo", b"bar", True, b"foo"]
assert await pipe2.execute() == [b"bar", True, b"foo"]

await dp.stop()


async def test_cluster(request):

dp = DelayProxy(addr=("localhost", 5381), redis_addr=("localhost", 6372), delay=0.1)
await dp.start()

r = RedisCluster.from_url("redis://localhost:5381")
await r.initialize()
await r.set("foo", "foo")
await r.set("bar", "bar")

t = asyncio.create_task(r.get("foo"))
await asyncio.sleep(0.050)
t.cancel()
try:
await t
except asyncio.CancelledError:
pytest.fail("connection is left open with unread response")

assert await r.get("bar") == b"bar"
assert await r.ping()
assert await r.get("foo") == b"foo"

await dp.stop()