Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: replace async_timeout by asyncio.timeout #2602

Merged
merged 1 commit into from Mar 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES
@@ -1,3 +1,4 @@
* Use asyncio.timeout() instead of async_timeout.timeout() for python >= 3.11 (#2602)
* Add test and fix async HiredisParser when reading during a disconnect() (#2349)
* Use hiredis-py pack_command if available.
* Support `.unlink()` in ClusterPipeline
Expand Down
21 changes: 13 additions & 8 deletions redis/asyncio/connection.py
Expand Up @@ -5,6 +5,7 @@
import os
import socket
import ssl
import sys
import threading
import weakref
from itertools import chain
Expand All @@ -24,7 +25,11 @@
)
from urllib.parse import ParseResult, parse_qs, unquote, urlparse

import async_timeout
if sys.version_info.major >= 3 and sys.version_info.minor >= 11:
from asyncio import timeout as async_timeout
else:
from async_timeout import timeout as async_timeout


from redis.asyncio.retry import Retry
from redis.backoff import NoBackoff
Expand Down Expand Up @@ -242,7 +247,7 @@ async def can_read_destructive(self) -> bool:
if self._stream is None:
raise RedisError("Buffer is closed.")
try:
async with async_timeout.timeout(0):
async with async_timeout(0):
return await self._stream.read(1)
except asyncio.TimeoutError:
return False
Expand Down Expand Up @@ -381,7 +386,7 @@ async def can_read_destructive(self):
if self._reader.gets():
return True
try:
async with async_timeout.timeout(0):
async with async_timeout(0):
return await self.read_from_socket()
except asyncio.TimeoutError:
return False
Expand Down Expand Up @@ -636,7 +641,7 @@ async def connect(self):

async def _connect(self):
"""Create a TCP socket connection"""
async with async_timeout.timeout(self.socket_connect_timeout):
async with async_timeout(self.socket_connect_timeout):
reader, writer = await asyncio.open_connection(
host=self.host,
port=self.port,
Expand Down Expand Up @@ -723,7 +728,7 @@ async def on_connect(self) -> None:
async def disconnect(self, nowait: bool = False) -> None:
"""Disconnects from the Redis server"""
try:
async with async_timeout.timeout(self.socket_connect_timeout):
async with async_timeout(self.socket_connect_timeout):
self._parser.on_disconnect()
if not self.is_connected:
return
Expand Down Expand Up @@ -828,7 +833,7 @@ async def read_response(
read_timeout = timeout if timeout is not None else self.socket_timeout
try:
if read_timeout is not None:
async with async_timeout.timeout(read_timeout):
async with async_timeout(read_timeout):
response = await self._parser.read_response(
disable_decoding=disable_decoding
)
Expand Down Expand Up @@ -1119,7 +1124,7 @@ def repr_pieces(self) -> Iterable[Tuple[str, Union[str, int]]]:
return pieces

async def _connect(self):
async with async_timeout.timeout(self.socket_connect_timeout):
async with async_timeout(self.socket_connect_timeout):
reader, writer = await asyncio.open_unix_connection(path=self.path)
self._reader = reader
self._writer = writer
Expand Down Expand Up @@ -1590,7 +1595,7 @@ async def get_connection(self, command_name, *keys, **options):
# self.timeout then raise a ``ConnectionError``.
connection = None
try:
async with async_timeout.timeout(self.timeout):
async with async_timeout(self.timeout):
connection = await self.pool.get()
except (asyncio.QueueEmpty, asyncio.TimeoutError):
# Note that this is not caught by the redis client and will be
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Expand Up @@ -34,7 +34,7 @@
install_requires=[
'importlib-metadata >= 1.0; python_version < "3.8"',
'typing-extensions; python_version<"3.8"',
"async-timeout>=4.0.2",
'async-timeout>=4.0.2; python_version<"3.11"',
],
classifiers=[
"Development Status :: 5 - Production/Stable",
Expand Down
22 changes: 13 additions & 9 deletions tests/test_asyncio/test_pubsub.py
Expand Up @@ -5,7 +5,11 @@
from typing import Optional
from unittest.mock import patch

import async_timeout
if sys.version_info.major >= 3 and sys.version_info.minor >= 11:
from asyncio import timeout as async_timeout
else:
from async_timeout import timeout as async_timeout

import pytest
import pytest_asyncio

Expand All @@ -21,7 +25,7 @@ def with_timeout(t):
def wrapper(corofunc):
@functools.wraps(corofunc)
async def run(*args, **kwargs):
async with async_timeout.timeout(t):
async with async_timeout(t):
return await corofunc(*args, **kwargs)

return run
Expand Down Expand Up @@ -648,7 +652,7 @@ async def test_reconnect_listen(self, r: redis.Redis, pubsub):

async def loop():
# must make sure the task exits
async with async_timeout.timeout(2):
async with async_timeout(2):
nonlocal interrupt
await pubsub.subscribe("foo")
while True:
Expand Down Expand Up @@ -677,7 +681,7 @@ async def loop_step():

task = asyncio.get_running_loop().create_task(loop())
# get the initial connect message
async with async_timeout.timeout(1):
async with async_timeout(1):
message = await messages.get()
assert message == {
"channel": b"foo",
Expand Down Expand Up @@ -776,7 +780,7 @@ def callback(message):
if n == 1:
break
await asyncio.sleep(0.1)
async with async_timeout.timeout(0.1):
async with async_timeout(0.1):
message = await messages.get()
task.cancel()
# we expect a cancelled error, not the Runtime error
Expand Down Expand Up @@ -839,7 +843,7 @@ async def test_reconnect_socket_error(self, r: redis.Redis, method):
Test that a socket error will cause reconnect
"""
try:
async with async_timeout.timeout(self.timeout):
async with async_timeout(self.timeout):
await self.mysetup(r, method)
# now, disconnect the connection, and wait for it to be re-established
async with self.cond:
Expand Down Expand Up @@ -868,7 +872,7 @@ async def test_reconnect_disconnect(self, r: redis.Redis, method):
Test that a manual disconnect() will cause reconnect
"""
try:
async with async_timeout.timeout(self.timeout):
async with async_timeout(self.timeout):
await self.mysetup(r, method)
# now, disconnect the connection, and wait for it to be re-established
async with self.cond:
Expand Down Expand Up @@ -923,7 +927,7 @@ async def loop_step_get_message(self):
async def loop_step_listen(self):
# get a single message via listen()
try:
async with async_timeout.timeout(0.1):
async with async_timeout(0.1):
async for message in self.pubsub.listen():
await self.messages.put(message)
return True
Expand All @@ -947,7 +951,7 @@ async def test_outer_timeout(self, r: redis.Redis):
assert pubsub.connection.is_connected

async def get_msg_or_timeout(timeout=0.1):
async with async_timeout.timeout(timeout):
async with async_timeout(timeout):
# blocking method to return messages
while True:
response = await pubsub.parse_response(block=True)
Expand Down