Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure writer is always reset on completion #7815

Merged
merged 12 commits into from Nov 12, 2023
1 change: 1 addition & 0 deletions CHANGES/7815.bugfix
@@ -0,0 +1 @@
Fixed an issue where the client could go into an infinite loop. -- by :user:`Dreamsorcerer`
72 changes: 48 additions & 24 deletions aiohttp/client_reqrep.py
Expand Up @@ -56,7 +56,13 @@
reify,
set_result,
)
from .http import SERVER_SOFTWARE, HttpVersion10, HttpVersion11, StreamWriter
from .http import (
SERVER_SOFTWARE,
HttpVersion,
HttpVersion10,
HttpVersion11,
StreamWriter,
)
from .log import client_logger
from .streams import StreamReader
from .typedefs import (
Expand Down Expand Up @@ -178,7 +184,7 @@
auth = None
response = None

_writer = None # async task for streaming data
__writer = None # async task for streaming data
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Dreamsorcerer FYI using double leading underscored is usually discouraged due to how it's re-exposed in the inherited objects...

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It was on purpose, if someone messes with this in an inherited class, they may cause the program to hang or similar.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Dreamsorcerer but why do you want it to be exposed for messing it up in the first place?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure I understand what you mean? I'm trying to discourage anyone from accessing/setting this attribute directly.

Copy link
Member

@webknjaz webknjaz Nov 30, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Dreamsorcerer yeah, that's a common mistake and is not the primary use case of the name mangling. It's actually discouraged to use leading double underscores because of the side effects this may cause for people who aren't supposed to know the base class tree implementation details. A single underscore is preferred.

Double leading underscore is advertised to be a hack for dealing with name clashes in the context of inheritance. The end-users would need to know about this implementation detail and never call their classes and attributes with the same name as one of the indirect base classes up the chain.

One CPython Core Dev once told me that the name mangling mechanism is a half-baked band-aid.

I couldn't find any clearly documented dangers of using this, so I had to draft my own example. Here you go:

Python 3.12.0rc3 (main, Sep 22 2023, 15:37:03) [GCC 12.3.1 20230526] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> class SomeBase:
...     def __init__(self): self.__private_thing = 'Framework Super Base'
...     def do_stuff(self): print(f'The ultimate base: {self.__private_thing=}')
... 
>>> class AFrameworkProvidedSomeClass(SomeBase): pass
... 
>>> # <<<<< BOUNDARY BETWEEN THE FRAMEWORK AND THE END-USER APP >>>>>
>>> 
>>> class SomeBase(AFrameworkProvidedSomeClass):  # End-user project defining their internal base for reuse
...     def __init__(self):
...         super().__init__()
...         self.__private_thing = 'End-user App Base'
...     def show_that_private_thing(self): print(f'Our attr: {self.__private_thing=}')
... 
>>> class BaseForMyApp(SomeBase): pass
... 
>>> class MyAppFinal(BaseForMyApp):
...     def do_our_thing(self): print(f'{dir(self)=}')
... 
>>> 
>>> 
>>> # <<<<< THE END-USER APP JUST WORKS WITH THE CLASS ON THEIR SIDE, NOT KNOWING THE FRAMEWORK IMPLEMENTATION DETAILS >>>>>
>>> 
>>> 
>>> app = MyAppFinal()
>>> app.do_stuff()  # This private API should be predictable because, ...right? Nope!
The ultimate base: self.__private_thing='End-user App Base'
>>> app.show_that_private_thing()
Our attr: self.__private_thing='End-user App Base'
>>> app._SomeBase__private_thing
'End-user App Base'
>>> app.do_our_thing()
dir(self)=['_SomeBase__private_thing', '__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__getstate__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__', 'do_our_thing', 'do_stuff', 'show_that_private_thing']
>>> MyAppFinal.__mro__
(<class '__main__.MyAppFinal'>, <class '__main__.BaseForMyApp'>, <class '__main__.SomeBase'>, <class '__main__.AFrameworkProvidedSomeClass'>, <class '__main__.SomeBase'>, <class 'object'>)
>>> 
>>> 
>>> # What if we call it differently?
>>> class AnotherBase(AFrameworkProvidedSomeClass):  # End-user project defining another internal base for reuse
...     def __init__(self):
...         super().__init__()
...         self.__private_thing = 'App Base'
...     def show_that_private_thing(self): print(f'Our attr: {self.__private_thing=}')
... 
>>> class BaseForMyApp(AnotherBase): pass
...
>>> class MyAppFinal(BaseForMyApp):
...     def do_our_thing(self): print(f'{dir(self)=}')
... 
>>> app = MyAppFinal()
>>> app.do_stuff()  # This private API happens to be predictable because, ...the end-user was lucky not use use the same names. By accident.
The ultimate base: self.__private_thing='Framework Super Base'
>>> app.show_that_private_thing()
Our attr: self.__private_thing='App Base'
>>> app._SomeBase__private_thing
'Framework Super Base'
>>> app.do_our_thing()
dir(self)=['_AnotherBase__private_thing', '_SomeBase__private_thing', '__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__getstate__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__', 'do_our_thing', 'do_stuff', 'show_that_private_thing']
>>> MyAppFinal.__mro__
(<class '__main__.MyAppFinal'>, <class '__main__.BaseForMyApp'>, <class '__main__.AnotherBase'>, <class '__main__.AFrameworkProvidedSomeClass'>, <class '__main__.SomeBase'>, <class 'object'>)

In this example, the end-user unknowingly adds a class that happens to have the same base name as something from a framework. And decides to use a "__private" attribute that they would "own".
This shouldn't influence anything, right? Well, no. They try to use the framework's public API and it "works", except that their own "private" attribute leaked into the namespace where the framework's attribute with the same name is defined, effectively shadowing it. This is a straight way to break the framework guarantees, never realizing it. And their editor helpfully doesn't auto-complete the super-base private attribute (also because it's actually exposed as _SomeBase__private_thing at the time it's evaluated).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, and another one:

Python 3.12.0rc3 (main, Sep 22 2023, 15:37:03) [GCC 12.3.1 20230526] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> class SomeBase:
...     def __init__(self): self.__private_thing = 'Framework Super Base'
...     def do_stuff(self): print(f'The ultimate base: {self.__private_thing=}')
... 
>>> class AFrameworkProvidedSomeClass(SomeBase): pass
... 
>>> class AnotherBase(AFrameworkProvidedSomeClass):
...     def access_app_stuff(self): print(f'The private app thing: {self.__private_thing=}')
... 
>>> class SomeBase(AnotherBase):
...     def access_app_stuff(self): print(f'The private app thing: {self.__private_thing=}')
... 
>>> AnotherBase().do_stuff()
The ultimate base: self.__private_thing='Framework Super Base'
>>> SomeBase().do_stuff()
The ultimate base: self.__private_thing='Framework Super Base'
>>> AnotherBase().access_app_stuff()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "<stdin>", line 2, in access_app_stuff
AttributeError: 'AnotherBase' object has no attribute '_AnotherBase__private_thing'. Did you mean: '_SomeBase__private_thing'?
>>> SomeBase().access_app_stuff()
The private app thing: self.__private_thing='Framework Super Base'
>>>

_continue = None # waiter future for '100 Continue' response

# N.B.
Expand Down Expand Up @@ -265,6 +271,21 @@
traces = []
self._traces = traces

def __reset_writer(self, _: object = None) -> None:
self.__writer = None

@property
def _writer(self) -> Optional["asyncio.Task[None]"]:
return self.__writer

@_writer.setter
def _writer(self, writer: Optional["asyncio.Task[None]"]) -> None:
if self.__writer is not None:
self.__writer.remove_done_callback(self.__reset_writer)
self.__writer = writer
if writer is not None:
writer.add_done_callback(self.__reset_writer)

def is_ssl(self) -> bool:
return self.url.scheme in ("https", "wss")

Expand Down Expand Up @@ -563,8 +584,6 @@
else:
await writer.write_eof()
protocol.start_timeout()
finally:
self._writer = None

async def send(self, conn: "Connection") -> "ClientResponse":
# Specify request target:
Expand Down Expand Up @@ -649,16 +668,14 @@

async def close(self) -> None:
if self._writer is not None:
try:
with contextlib.suppress(asyncio.CancelledError):
await self._writer
finally:
self._writer = None
with contextlib.suppress(asyncio.CancelledError):
await self._writer

def terminate(self) -> None:
if self._writer is not None:
if not self.loop.is_closed():
self._writer.cancel()
self._writer.remove_done_callback(self.__reset_writer)
self._writer = None

async def _on_chunk_request_sent(self, method: str, url: URL, chunk: bytes) -> None:
Expand All @@ -677,9 +694,9 @@
# but will be set by the start() method.
# As the end user will likely never see the None values, we cheat the types below.
# from the Status-Line of the response
version = None # HTTP-Version
version: Optional[HttpVersion] = None # HTTP-Version
status: int = None # type: ignore[assignment] # Status-Code
reason = None # Reason-Phrase
reason: Optional[str] = None # Reason-Phrase

content: StreamReader = None # type: ignore[assignment] # Payload stream
_headers: CIMultiDictProxy[str] = None # type: ignore[assignment]
Expand All @@ -691,6 +708,7 @@
# post-init stage allows to not change ctor signature
_closed = True # to allow __del__ for non-initialized properly response
_released = False
__writer = None

def __init__(
self,
Expand Down Expand Up @@ -737,6 +755,21 @@
if loop.get_debug():
self._source_traceback = traceback.extract_stack(sys._getframe(1))

def __reset_writer(self, _: object = None) -> None:
self.__writer = None

@property
def _writer(self) -> Optional["asyncio.Task[None]"]:
return self.__writer

@_writer.setter
def _writer(self, writer: Optional["asyncio.Task[None]"]) -> None:
if self.__writer is not None:
self.__writer.remove_done_callback(self.__reset_writer)

Check warning on line 768 in aiohttp/client_reqrep.py

View check run for this annotation

Codecov / codecov/patch

aiohttp/client_reqrep.py#L768

Added line #L768 was not covered by tests
self.__writer = writer
if writer is not None:
writer.add_done_callback(self.__reset_writer)

@reify
def url(self) -> URL:
return self._url
Expand Down Expand Up @@ -797,7 +830,7 @@
"ascii", "backslashreplace"
).decode("ascii")
else:
ascii_encodable_reason = self.reason
ascii_encodable_reason = "None"
print(
"<ClientResponse({}) [{} {}]>".format(
ascii_encodable_url, self.status, ascii_encodable_reason
Expand Down Expand Up @@ -978,18 +1011,12 @@

async def _wait_released(self) -> None:
if self._writer is not None:
try:
await self._writer
finally:
self._writer = None
await self._writer
self._release_connection()

def _cleanup_writer(self) -> None:
if self._writer is not None:
if self._writer.done():
self._writer = None
else:
self._writer.cancel()
self._writer.cancel()
self._session = None

def _notify_content(self) -> None:
Expand All @@ -1001,10 +1028,7 @@

async def wait_for_close(self) -> None:
if self._writer is not None:
try:
await self._writer
finally:
self._writer = None
await self._writer
self.release()

async def read(self) -> bytes:
Expand Down
5 changes: 4 additions & 1 deletion tests/test_client_response.py
Expand Up @@ -4,7 +4,7 @@
import gc
import sys
from json import JSONDecodeError
from typing import Any
from typing import Any, Callable
from unittest import mock

import pytest
Expand All @@ -22,6 +22,9 @@ class WriterMock(mock.AsyncMock):
def __await__(self) -> None:
return self().__await__()

def add_done_callback(self, cb: Callable[[], None]) -> None:
cb()

def done(self) -> bool:
return True

Expand Down
18 changes: 9 additions & 9 deletions tests/test_proxy.py
Expand Up @@ -199,7 +199,7 @@ def test_proxy_server_hostname_default(self, ClientRequestMock) -> None:
"get",
URL("http://proxy.example.com"),
request_info=mock.Mock(),
writer=mock.Mock(),
writer=None,
continue100=None,
timer=TimerNoop(),
traces=[],
Expand Down Expand Up @@ -261,7 +261,7 @@ def test_proxy_server_hostname_override(self, ClientRequestMock) -> None:
"get",
URL("http://proxy.example.com"),
request_info=mock.Mock(),
writer=mock.Mock(),
writer=None,
continue100=None,
timer=TimerNoop(),
traces=[],
Expand Down Expand Up @@ -323,7 +323,7 @@ def test_https_connect(self, ClientRequestMock: Any) -> None:
"get",
URL("http://proxy.example.com"),
request_info=mock.Mock(),
writer=mock.Mock(),
writer=None,
continue100=None,
timer=TimerNoop(),
traces=[],
Expand Down Expand Up @@ -383,7 +383,7 @@ def test_https_connect_certificate_error(self, ClientRequestMock: Any) -> None:
"get",
URL("http://proxy.example.com"),
request_info=mock.Mock(),
writer=mock.Mock(),
writer=None,
continue100=None,
timer=TimerNoop(),
traces=[],
Expand Down Expand Up @@ -437,7 +437,7 @@ def test_https_connect_ssl_error(self, ClientRequestMock: Any) -> None:
"get",
URL("http://proxy.example.com"),
request_info=mock.Mock(),
writer=mock.Mock(),
writer=None,
continue100=None,
timer=TimerNoop(),
traces=[],
Expand Down Expand Up @@ -493,7 +493,7 @@ def test_https_connect_http_proxy_error(self, ClientRequestMock: Any) -> None:
"get",
URL("http://proxy.example.com"),
request_info=mock.Mock(),
writer=mock.Mock(),
writer=None,
continue100=None,
timer=TimerNoop(),
traces=[],
Expand Down Expand Up @@ -552,7 +552,7 @@ def test_https_connect_resp_start_error(self, ClientRequestMock: Any) -> None:
"get",
URL("http://proxy.example.com"),
request_info=mock.Mock(),
writer=mock.Mock(),
writer=None,
continue100=None,
timer=TimerNoop(),
traces=[],
Expand Down Expand Up @@ -663,7 +663,7 @@ def test_https_connect_pass_ssl_context(self, ClientRequestMock: Any) -> None:
"get",
URL("http://proxy.example.com"),
request_info=mock.Mock(),
writer=mock.Mock(),
writer=None,
continue100=None,
timer=TimerNoop(),
traces=[],
Expand Down Expand Up @@ -734,7 +734,7 @@ def test_https_auth(self, ClientRequestMock: Any) -> None:
"get",
URL("http://proxy.example.com"),
request_info=mock.Mock(),
writer=mock.Mock(),
writer=None,
continue100=None,
timer=TimerNoop(),
traces=[],
Expand Down