Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix websocket connection leak #7978

Merged
merged 13 commits into from
Dec 18, 2023
1 change: 1 addition & 0 deletions CHANGES/7978.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix websocket connection leak
93 changes: 53 additions & 40 deletions aiohttp/web_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,9 +162,8 @@ def _send_heartbeat(self) -> None:
def _pong_not_received(self) -> None:
if self._req is not None and self._req.transport is not None:
self._closed = True
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE)
bdraco marked this conversation as resolved.
Show resolved Hide resolved
self._exception = asyncio.TimeoutError()
self._req.transport.close()

async def prepare(self, request: BaseRequest) -> AbstractStreamWriter:
# make pre-check to don't hide it by do_handshake() exceptions
Expand Down Expand Up @@ -382,7 +381,10 @@ async def write_eof(self) -> None: # type: ignore[override]
await self.close()
self._eof_sent = True

async def close(self, *, code: int = WSCloseCode.OK, message: bytes = b"") -> bool:
async def close(
self, *, code: int = WSCloseCode.OK, message: bytes = b"", drain: bool = True
) -> bool:
"""Close websocket connection."""
if self._writer is None:
raise RuntimeError("Call .prepare() first")

Expand All @@ -396,46 +398,53 @@ async def close(self, *, code: int = WSCloseCode.OK, message: bytes = b"") -> bo
reader.feed_data(WS_CLOSING_MESSAGE, 0)
await self._waiting

if not self._closed:
self._closed = True
try:
await self._writer.close(code, message)
writer = self._payload_writer
assert writer is not None
await writer.drain()
except (asyncio.CancelledError, asyncio.TimeoutError):
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
raise
except Exception as exc:
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
self._exception = exc
return True
if self._closed:
return False

if self._closing:
return True
self._closed = True
try:
await self._writer.close(code, message)
writer = self._payload_writer
assert writer is not None
if drain:
await writer.drain()
except (asyncio.CancelledError, asyncio.TimeoutError):
self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE)
raise
except Exception as exc:
self._exception = exc
self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE)
return True

reader = self._reader
assert reader is not None
try:
async with async_timeout.timeout(self._timeout):
msg = await reader.read()
except asyncio.CancelledError:
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
raise
except Exception as exc:
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
self._exception = exc
return True
if self._closing:
return True

if msg.type == WSMsgType.CLOSE:
self._close_code = msg.data
return True
reader = self._reader
assert reader is not None
try:
async with async_timeout.timeout(self._timeout):
msg = await reader.read()
except asyncio.CancelledError:
self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE)
raise
except Exception as exc:
self._exception = exc
self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE)
return True

self._close_code = WSCloseCode.ABNORMAL_CLOSURE
self._exception = asyncio.TimeoutError()
if msg.type == WSMsgType.CLOSE:
self._set_code_close_transport(msg.data)
return True
else:
return False

self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE)
self._exception = asyncio.TimeoutError()
return True

def _set_code_close_transport(self, code: WSCloseCode) -> None:
"""Set the close code and close the transport."""
self._close_code = code
if self._req is not None and self._req.transport is not None:
self._req.transport.close()

async def receive(self, timeout: Optional[float] = None) -> WSMessage:
if self._reader is None:
Expand Down Expand Up @@ -466,7 +475,7 @@ async def receive(self, timeout: Optional[float] = None) -> WSMessage:
set_result(waiter, True)
self._waiting = None
except (asyncio.CancelledError, asyncio.TimeoutError):
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE)
raise
except EofStream:
self._close_code = WSCloseCode.OK
Expand All @@ -488,7 +497,11 @@ async def receive(self, timeout: Optional[float] = None) -> WSMessage:
self._close_code = msg.data
# Could be closed while awaiting reader.
if not self._closed and self._autoclose: # type: ignore[redundant-expr]
await self.close()
# The client is likely going to close the
# connection out from under us so we do not
# want to drain any pending writes as it will
# likely result writing to a broken pipe.
await self.close(drain=False)
elif msg.type == WSMsgType.CLOSING:
self._closing = True
elif msg.type == WSMsgType.PING and self._autoping:
Expand Down
12 changes: 11 additions & 1 deletion docs/web_reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -988,6 +988,14 @@ and :ref:`aiohttp-web-signals` handlers::

.. versionadded:: 3.3

:param bool autoclose: Close connection when the client sends
a :const:`~aiohttp.WSMsgType.CLOSE` message,
``True`` by default. If set to ``False``,
the connection is not closed and the
caller is responsible for calling
``request.transport.close()`` to avoid
leaking resources.


The class supports ``async for`` statement for iterating over
incoming messages::
Expand Down Expand Up @@ -1164,7 +1172,7 @@ and :ref:`aiohttp-web-signals` handlers::
The method is converted into :term:`coroutine`,
*compress* parameter added.

.. method:: close(*, code=WSCloseCode.OK, message=b'')
.. method:: close(*, code=WSCloseCode.OK, message=b'', drain=True)
:async:

A :ref:`coroutine<coroutine>` that initiates closing
Expand All @@ -1178,6 +1186,8 @@ and :ref:`aiohttp-web-signals` handlers::
:class:`str` (converted to *UTF-8* encoded bytes)
or :class:`bytes`.

:param bool drain: drain outgoing buffer before closing connection.

:raise RuntimeError: if connection is not started

.. method:: receive(timeout=None)
Expand Down
26 changes: 26 additions & 0 deletions tests/test_web_websocket.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# type: ignore
import asyncio
import time
from typing import Any
from unittest import mock

Expand Down Expand Up @@ -139,6 +140,20 @@
await ws.write(b"data")


async def test_heartbeat_timeout(make_request: Any) -> None:
"""Verify the transport is closed when the heartbeat timeout is reached."""
loop = asyncio.get_running_loop()
future = loop.create_future()
req = make_request("GET", "/")
lowest_time = time.get_clock_info("monotonic").resolution
req._protocol._timeout_ceil_threshold = lowest_time
ws = WebSocketResponse(heartbeat=lowest_time, timeout=lowest_time)
await ws.prepare(req)
ws._req.transport.close.side_effect = lambda: future.set_result(None)
await future
Dismissed Show dismissed Hide dismissed
assert ws.closed


def test_websocket_ready() -> None:
websocket_ready = WebSocketReady(True, "chat")
assert websocket_ready.ok is True
Expand Down Expand Up @@ -207,6 +222,7 @@
await ws.prepare(req)
ws._reader.feed_data(WS_CLOSED_MESSAGE, 0)
await ws.close()
assert len(ws._req.transport.close.mock_calls) == 1

with pytest.raises(ConnectionError):
await ws.send_str("string")
Expand Down Expand Up @@ -263,6 +279,8 @@
ws._reader.feed_data(WS_CLOSED_MESSAGE, 0)
assert await ws.close(code=1, message="message1")
assert ws.closed
assert len(ws._req.transport.close.mock_calls) == 1

assert not (await ws.close(code=2, message="message2"))


Expand Down Expand Up @@ -296,12 +314,15 @@
req = make_request("GET", "/")
ws = WebSocketResponse()
await ws.prepare(req)
assert len(ws._req.transport.close.mock_calls) == 0

ws._reader.feed_data(WS_CLOSED_MESSAGE, 0)
await ws.close()

await ws.write_eof()
await ws.write_eof()
await ws.write_eof()
assert len(ws._req.transport.close.mock_calls) == 1


async def test_receive_eofstream_in_reader(make_request: Any, loop: Any) -> None:
Expand All @@ -327,6 +348,7 @@
req = make_request("GET", "/")
ws = WebSocketResponse()
await ws.prepare(req)
assert len(ws._req.transport.close.mock_calls) == 0

ws._reader = mock.Mock()
res = loop.create_future()
Expand All @@ -336,6 +358,8 @@
with pytest.raises(asyncio.TimeoutError):
await ws.receive()

assert len(ws._req.transport.close.mock_calls) == 1


async def test_multiple_receive_on_close_connection(make_request: Any) -> None:
req = make_request("GET", "/")
Expand Down Expand Up @@ -367,13 +391,15 @@
req = make_request("GET", "/")
ws = WebSocketResponse()
await ws.prepare(req)
assert len(ws._req.transport.close.mock_calls) == 0

exc = ValueError()
ws._writer = mock.Mock()
ws._writer.close.side_effect = exc
await ws.close()
assert ws.closed
assert ws.exception() is exc
assert len(ws._req.transport.close.mock_calls) == 1

ws._closed = False
ws._writer.close.side_effect = asyncio.CancelledError()
Expand Down