Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure writer is always reset on completion (#7815) #7826

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES/7815.bugfix
@@ -0,0 +1 @@
Fixed an issue where the client could go into an infinite loop. -- by :user:`Dreamsorcerer`
74 changes: 49 additions & 25 deletions aiohttp/client_reqrep.py
Expand Up @@ -53,7 +53,13 @@
reify,
set_result,
)
from .http import SERVER_SOFTWARE, HttpVersion10, HttpVersion11, StreamWriter
from .http import (
SERVER_SOFTWARE,
HttpVersion,
HttpVersion10,
HttpVersion11,
StreamWriter,
)
from .log import client_logger
from .streams import StreamReader
from .typedefs import (
Expand Down Expand Up @@ -241,7 +247,7 @@ class ClientRequest:
auth = None
response = None

_writer = None # async task for streaming data
__writer = None # async task for streaming data
_continue = None # waiter future for '100 Continue' response

# N.B.
Expand Down Expand Up @@ -332,6 +338,21 @@ def __init__(
traces = []
self._traces = traces

def __reset_writer(self, _: object = None) -> None:
self.__writer = None

@property
def _writer(self) -> Optional["asyncio.Task[None]"]:
return self.__writer

@_writer.setter
def _writer(self, writer: Optional["asyncio.Task[None]"]) -> None:
if self.__writer is not None:
self.__writer.remove_done_callback(self.__reset_writer)
self.__writer = writer
if writer is not None:
writer.add_done_callback(self.__reset_writer)

def is_ssl(self) -> bool:
return self.url.scheme in ("https", "wss")

Expand Down Expand Up @@ -625,8 +646,6 @@ async def write_bytes(
else:
await writer.write_eof()
protocol.start_timeout()
finally:
self._writer = None

async def send(self, conn: "Connection") -> "ClientResponse":
# Specify request target:
Expand Down Expand Up @@ -711,16 +730,14 @@ async def send(self, conn: "Connection") -> "ClientResponse":

async def close(self) -> None:
if self._writer is not None:
try:
with contextlib.suppress(asyncio.CancelledError):
await self._writer
finally:
self._writer = None
with contextlib.suppress(asyncio.CancelledError):
await self._writer

def terminate(self) -> None:
if self._writer is not None:
if not self.loop.is_closed():
self._writer.cancel()
self._writer.remove_done_callback(self.__reset_writer)
self._writer = None

async def _on_chunk_request_sent(self, method: str, url: URL, chunk: bytes) -> None:
Expand All @@ -740,9 +757,9 @@ class ClientResponse(HeadersMixin):
# but will be set by the start() method.
# As the end user will likely never see the None values, we cheat the types below.
# from the Status-Line of the response
version = None # HTTP-Version
status: int = None # type: ignore[assignment] # Status-Code
reason = None # Reason-Phrase
version: Optional[HttpVersion] = None # HTTP-Version
status: int = None # type: ignore[assignment] # Status-Code
reason: Optional[str] = None # Reason-Phrase

content: StreamReader = None # type: ignore[assignment] # Payload stream
_headers: CIMultiDictProxy[str] = None # type: ignore[assignment]
Expand All @@ -754,6 +771,7 @@ class ClientResponse(HeadersMixin):
# post-init stage allows to not change ctor signature
_closed = True # to allow __del__ for non-initialized properly response
_released = False
__writer = None

def __init__(
self,
Expand Down Expand Up @@ -799,6 +817,21 @@ def __init__(
if loop.get_debug():
self._source_traceback = traceback.extract_stack(sys._getframe(1))

def __reset_writer(self, _: object = None) -> None:
self.__writer = None

@property
def _writer(self) -> Optional["asyncio.Task[None]"]:
return self.__writer

@_writer.setter
def _writer(self, writer: Optional["asyncio.Task[None]"]) -> None:
if self.__writer is not None:
self.__writer.remove_done_callback(self.__reset_writer)
self.__writer = writer
if writer is not None:
writer.add_done_callback(self.__reset_writer)

@reify
def url(self) -> URL:
return self._url
Expand Down Expand Up @@ -863,7 +896,7 @@ def __repr__(self) -> str:
"ascii", "backslashreplace"
).decode("ascii")
else:
ascii_encodable_reason = self.reason
ascii_encodable_reason = "None"
print(
"<ClientResponse({}) [{} {}]>".format(
ascii_encodable_url, self.status, ascii_encodable_reason
Expand Down Expand Up @@ -1044,18 +1077,12 @@ def _release_connection(self) -> None:

async def _wait_released(self) -> None:
if self._writer is not None:
try:
await self._writer
finally:
self._writer = None
await self._writer
self._release_connection()

def _cleanup_writer(self) -> None:
if self._writer is not None:
if self._writer.done():
self._writer = None
else:
self._writer.cancel()
self._writer.cancel()
self._session = None

def _notify_content(self) -> None:
Expand All @@ -1066,10 +1093,7 @@ def _notify_content(self) -> None:

async def wait_for_close(self) -> None:
if self._writer is not None:
try:
await self._writer
finally:
self._writer = None
await self._writer
self.release()

async def read(self) -> bytes:
Expand Down
20 changes: 16 additions & 4 deletions tests/test_client_request.py
Expand Up @@ -5,7 +5,7 @@
import urllib.parse
import zlib
from http.cookies import BaseCookie, Morsel, SimpleCookie
from typing import Any, Dict, Optional
from typing import Any, Callable, Dict, Optional
from unittest import mock

import pytest
Expand All @@ -24,6 +24,17 @@
from aiohttp.test_utils import make_mocked_coro


class WriterMock(mock.AsyncMock):
def __await__(self) -> None:
return self().__await__()

def add_done_callback(self, cb: Callable[[], None]) -> None:
"""Dummy method."""

def remove_done_callback(self, cb: Callable[[], None]) -> None:
"""Dummy method."""


@pytest.fixture
def make_request(loop):
request = None
Expand Down Expand Up @@ -1167,7 +1178,7 @@ def read(self, decode=False):
async def test_oserror_on_write_bytes(loop, conn) -> None:
req = ClientRequest("POST", URL("http://python.org/"), loop=loop)

writer = mock.Mock()
writer = WriterMock()
writer.write.side_effect = OSError

await req.write_bytes(writer, conn)
Expand All @@ -1183,7 +1194,8 @@ async def test_terminate(loop, conn) -> None:
req = ClientRequest("get", URL("http://python.org"), loop=loop)
resp = await req.send(conn)
assert req._writer is not None
writer = req._writer = mock.Mock()
writer = req._writer = WriterMock()
writer.cancel = mock.Mock()

req.terminate()
assert req._writer is None
Expand All @@ -1201,7 +1213,7 @@ async def go():
req = ClientRequest("get", URL("http://python.org"))
resp = await req.send(conn)
assert req._writer is not None
writer = req._writer = mock.Mock()
writer = req._writer = WriterMock()

await asyncio.sleep(0.05)

Expand Down
4 changes: 4 additions & 0 deletions tests/test_client_response.py
Expand Up @@ -2,6 +2,7 @@

import gc
import sys
from typing import Callable
from unittest import mock

import pytest
Expand All @@ -19,6 +20,9 @@ class WriterMock(mock.AsyncMock):
def __await__(self) -> None:
return self().__await__()

def add_done_callback(self, cb: Callable[[], None]) -> None:
cb()

def done(self) -> bool:
return True

Expand Down
18 changes: 9 additions & 9 deletions tests/test_proxy.py
Expand Up @@ -202,7 +202,7 @@ def test_proxy_server_hostname_default(self, ClientRequestMock) -> None:
"get",
URL("http://proxy.example.com"),
request_info=mock.Mock(),
writer=mock.Mock(),
writer=None,
continue100=None,
timer=TimerNoop(),
traces=[],
Expand Down Expand Up @@ -264,7 +264,7 @@ def test_proxy_server_hostname_override(self, ClientRequestMock) -> None:
"get",
URL("http://proxy.example.com"),
request_info=mock.Mock(),
writer=mock.Mock(),
writer=None,
continue100=None,
timer=TimerNoop(),
traces=[],
Expand Down Expand Up @@ -326,7 +326,7 @@ def test_https_connect(self, ClientRequestMock) -> None:
"get",
URL("http://proxy.example.com"),
request_info=mock.Mock(),
writer=mock.Mock(),
writer=None,
continue100=None,
timer=TimerNoop(),
traces=[],
Expand Down Expand Up @@ -386,7 +386,7 @@ def test_https_connect_certificate_error(self, ClientRequestMock) -> None:
"get",
URL("http://proxy.example.com"),
request_info=mock.Mock(),
writer=mock.Mock(),
writer=None,
continue100=None,
timer=TimerNoop(),
traces=[],
Expand Down Expand Up @@ -440,7 +440,7 @@ def test_https_connect_ssl_error(self, ClientRequestMock) -> None:
"get",
URL("http://proxy.example.com"),
request_info=mock.Mock(),
writer=mock.Mock(),
writer=None,
continue100=None,
timer=TimerNoop(),
traces=[],
Expand Down Expand Up @@ -496,7 +496,7 @@ def test_https_connect_http_proxy_error(self, ClientRequestMock) -> None:
"get",
URL("http://proxy.example.com"),
request_info=mock.Mock(),
writer=mock.Mock(),
writer=None,
continue100=None,
timer=TimerNoop(),
traces=[],
Expand Down Expand Up @@ -555,7 +555,7 @@ def test_https_connect_resp_start_error(self, ClientRequestMock) -> None:
"get",
URL("http://proxy.example.com"),
request_info=mock.Mock(),
writer=mock.Mock(),
writer=None,
continue100=None,
timer=TimerNoop(),
traces=[],
Expand Down Expand Up @@ -666,7 +666,7 @@ def test_https_connect_pass_ssl_context(self, ClientRequestMock) -> None:
"get",
URL("http://proxy.example.com"),
request_info=mock.Mock(),
writer=mock.Mock(),
writer=None,
continue100=None,
timer=TimerNoop(),
traces=[],
Expand Down Expand Up @@ -737,7 +737,7 @@ def test_https_auth(self, ClientRequestMock) -> None:
"get",
URL("http://proxy.example.com"),
request_info=mock.Mock(),
writer=mock.Mock(),
writer=None,
continue100=None,
timer=TimerNoop(),
traces=[],
Expand Down