Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixing cancelled async futures #2666

Merged
merged 18 commits into from Mar 29, 2023
2 changes: 2 additions & 0 deletions .github/workflows/integration.yaml
Expand Up @@ -51,6 +51,7 @@ jobs:
timeout-minutes: 30
strategy:
max-parallel: 15
fail-fast: false
chayim marked this conversation as resolved.
Show resolved Hide resolved
matrix:
python-version: ['3.7', '3.8', '3.9', '3.10', '3.11', 'pypy-3.7', 'pypy-3.8', 'pypy-3.9']
test-type: ['standalone', 'cluster']
Expand Down Expand Up @@ -108,6 +109,7 @@ jobs:
name: Install package from commit hash
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python-version: ['3.7', '3.8', '3.9', '3.10', '3.11', 'pypy-3.7', 'pypy-3.8', 'pypy-3.9']
steps:
Expand Down
101 changes: 69 additions & 32 deletions redis/asyncio/client.py
Expand Up @@ -500,28 +500,37 @@ async def _disconnect_raise(self, conn: Connection, error: Exception):
):
raise error

# COMMAND EXECUTION AND PROTOCOL PARSING
async def execute_command(self, *args, **options):
"""Execute a command and return a parsed response"""
await self.initialize()
pool = self.connection_pool
command_name = args[0]
conn = self.connection or await pool.get_connection(command_name, **options)

if self.single_connection_client:
await self._single_conn_lock.acquire()
async def _try_send_command_parse_response(self, conn, *args, **options):
try:
return await conn.retry.call_with_retry(
lambda: self._send_command_parse_response(
conn, command_name, *args, **options
conn, args[0], *args, **options
),
lambda error: self._disconnect_raise(conn, error),
)
except asyncio.CancelledError:
await conn.disconnect(nowait=True)
raise
finally:
if self.single_connection_client:
self._single_conn_lock.release()
if not self.connection:
await pool.release(conn)
await self.connection_pool.release(conn)

# COMMAND EXECUTION AND PROTOCOL PARSING
async def execute_command(self, *args, **options):
"""Execute a command and return a parsed response"""
await self.initialize()
pool = self.connection_pool
command_name = args[0]
conn = self.connection or await pool.get_connection(command_name, **options)

if self.single_connection_client:
await self._single_conn_lock.acquire()

return await asyncio.shield(
self._try_send_command_parse_response(conn, *args, **options)
)

async def parse_response(
self, connection: Connection, command_name: Union[str, bytes], **options
Expand Down Expand Up @@ -760,10 +769,18 @@ async def _disconnect_raise_connect(self, conn, error):
is not a TimeoutError. Otherwise, try to reconnect
"""
await conn.disconnect()

if not (conn.retry_on_timeout and isinstance(error, TimeoutError)):
raise error
await conn.connect()

async def _try_execute(self, conn, command, *arg, **kwargs):
try:
return await command(*arg, **kwargs)
except asyncio.CancelledError:
await conn.disconnect()
raise

async def _execute(self, conn, command, *args, **kwargs):
"""
Connect manually upon disconnection. If the Redis server is down,
Expand All @@ -772,9 +789,11 @@ async def _execute(self, conn, command, *args, **kwargs):
called by the # connection to resubscribe us to any channels and
patterns we were previously listening to
"""
return await conn.retry.call_with_retry(
lambda: command(*args, **kwargs),
lambda error: self._disconnect_raise_connect(conn, error),
return await asyncio.shield(
conn.retry.call_with_retry(
lambda: self._try_execute(conn, command, *args, **kwargs),
lambda error: self._disconnect_raise_connect(conn, error),
)
)

async def parse_response(self, block: bool = True, timeout: float = 0):
Expand Down Expand Up @@ -1176,6 +1195,18 @@ async def _disconnect_reset_raise(self, conn, error):
await self.reset()
raise

async def _try_send_command_parse_response(self, conn, *args, **options):
try:
return await conn.retry.call_with_retry(
lambda: self._send_command_parse_response(
conn, args[0], *args, **options
),
lambda error: self._disconnect_reset_raise(conn, error),
)
except asyncio.CancelledError:
await conn.disconnect()
raise

async def immediate_execute_command(self, *args, **options):
"""
Execute a command immediately, but don't auto-retry on a
Expand All @@ -1191,13 +1222,13 @@ async def immediate_execute_command(self, *args, **options):
command_name, self.shard_hint
)
self.connection = conn

return await conn.retry.call_with_retry(
lambda: self._send_command_parse_response(
conn, command_name, *args, **options
),
lambda error: self._disconnect_reset_raise(conn, error),
)
try:
return await asyncio.shield(
self._try_send_command_parse_response(conn, *args, **options)
)
except asyncio.CancelledError:
chayim marked this conversation as resolved.
Show resolved Hide resolved
await conn.disconnect()
raise

def pipeline_execute_command(self, *args, **options):
"""
Expand Down Expand Up @@ -1364,6 +1395,19 @@ async def _disconnect_raise_reset(self, conn: Connection, error: Exception):
await self.reset()
raise

async def _try_execute(self, conn, execute, stack, raise_on_error):
try:
return await conn.retry.call_with_retry(
lambda: execute(conn, stack, raise_on_error),
lambda error: self._disconnect_raise_reset(conn, error),
)
except asyncio.CancelledError:
# not supposed to be possible, yet here we are
await conn.disconnect(nowait=True)
raise
finally:
await self.reset()

async def execute(self, raise_on_error: bool = True):
"""Execute all the commands in the current pipeline"""
stack = self.command_stack
Expand All @@ -1386,17 +1430,10 @@ async def execute(self, raise_on_error: bool = True):

try:
return await asyncio.shield(
conn.retry.call_with_retry(
lambda: execute(conn, stack, raise_on_error),
lambda error: self._disconnect_raise_reset(conn, error),
)
self._try_execute(conn, execute, stack, raise_on_error)
)
except asyncio.CancelledError:
# not supposed to be possible, yet here we are
await conn.disconnect(nowait=True)
raise
finally:
await self.reset()
except RuntimeError:
chayim marked this conversation as resolved.
Show resolved Hide resolved
self.reset()

async def discard(self):
"""Flushes all previously queued commands
Expand Down
21 changes: 14 additions & 7 deletions redis/asyncio/cluster.py
Expand Up @@ -1016,6 +1016,19 @@ async def _parse_and_release(self, connection, *args, **kwargs):
finally:
self._free.append(connection)

async def _try_parse_response(self, cmd, connection, ret):
try:
cmd.result = await asyncio.shield(
self.parse_response(connection, cmd.args[0], **cmd.kwargs)
)
except asyncio.CancelledError:
await connection.disconnect(nowait=True)
raise
except Exception as e:
cmd.result = e
ret = True
return ret

async def execute_pipeline(self, commands: List["PipelineCommand"]) -> bool:
# Acquire connection
connection = self.acquire_connection()
Expand All @@ -1028,13 +1041,7 @@ async def execute_pipeline(self, commands: List["PipelineCommand"]) -> bool:
# Read responses
ret = False
for cmd in commands:
try:
cmd.result = await self.parse_response(
connection, cmd.args[0], **cmd.kwargs
)
except Exception as e:
cmd.result = e
ret = True
ret = await asyncio.shield(self._try_parse_response(cmd, connection, ret))

# Release connection
self._free.append(connection)
Expand Down
17 changes: 0 additions & 17 deletions tests/test_asyncio/test_cluster.py
Expand Up @@ -340,23 +340,6 @@ async def test_from_url(self, request: FixtureRequest) -> None:
rc = RedisCluster.from_url("rediss://localhost:16379")
assert rc.connection_kwargs["connection_class"] is SSLConnection

async def test_asynckills(self, r) -> None:

await r.set("foo", "foo")
await r.set("bar", "bar")

t = asyncio.create_task(r.get("foo"))
await asyncio.sleep(1)
t.cancel()
try:
await t
except asyncio.CancelledError:
pytest.fail("connection is left open with unread response")

assert await r.get("bar") == b"bar"
assert await r.ping()
assert await r.get("foo") == b"foo"

async def test_max_connections(
self, create_redis: Callable[..., RedisCluster]
) -> None:
Expand Down
21 changes: 0 additions & 21 deletions tests/test_asyncio/test_connection.py
Expand Up @@ -44,27 +44,6 @@ async def test_invalid_response(create_redis):
await r.connection.disconnect()


async def test_asynckills():

for b in [True, False]:
r = Redis(single_connection_client=b)

await r.set("foo", "foo")
await r.set("bar", "bar")

t = asyncio.create_task(r.get("foo"))
await asyncio.sleep(1)
t.cancel()
try:
await t
except asyncio.CancelledError:
pytest.fail("connection left open with unread response")

assert await r.get("bar") == b"bar"
assert await r.ping()
assert await r.get("foo") == b"foo"


@pytest.mark.onlynoncluster
async def test_single_connection():
"""Test that concurrent requests on a single client are synchronised."""
Expand Down