Skip to content

Commit

Permalink
Fix websocket connection leak (#7978)
Browse files Browse the repository at this point in the history
  • Loading branch information
bdraco committed Dec 18, 2023
1 parent 5e44ba4 commit 6f1c608
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 41 deletions.
1 change: 1 addition & 0 deletions CHANGES/7978.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix websocket connection leak
93 changes: 53 additions & 40 deletions aiohttp/web_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,9 +162,8 @@ def _send_heartbeat(self) -> None:
def _pong_not_received(self) -> None:
if self._req is not None and self._req.transport is not None:
self._closed = True
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE)
self._exception = asyncio.TimeoutError()
self._req.transport.close()

async def prepare(self, request: BaseRequest) -> AbstractStreamWriter:
# make pre-check to don't hide it by do_handshake() exceptions
Expand Down Expand Up @@ -382,7 +381,10 @@ async def write_eof(self) -> None: # type: ignore[override]
await self.close()
self._eof_sent = True

async def close(self, *, code: int = WSCloseCode.OK, message: bytes = b"") -> bool:
async def close(
self, *, code: int = WSCloseCode.OK, message: bytes = b"", drain: bool = True
) -> bool:
"""Close websocket connection."""
if self._writer is None:
raise RuntimeError("Call .prepare() first")

Expand All @@ -396,46 +398,53 @@ async def close(self, *, code: int = WSCloseCode.OK, message: bytes = b"") -> bo
reader.feed_data(WS_CLOSING_MESSAGE, 0)
await self._waiting

if not self._closed:
self._closed = True
try:
await self._writer.close(code, message)
writer = self._payload_writer
assert writer is not None
await writer.drain()
except (asyncio.CancelledError, asyncio.TimeoutError):
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
raise
except Exception as exc:
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
self._exception = exc
return True
if self._closed:
return False

if self._closing:
return True
self._closed = True
try:
await self._writer.close(code, message)
writer = self._payload_writer
assert writer is not None
if drain:
await writer.drain()
except (asyncio.CancelledError, asyncio.TimeoutError):
self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE)
raise
except Exception as exc:
self._exception = exc
self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE)
return True

reader = self._reader
assert reader is not None
try:
async with async_timeout.timeout(self._timeout):
msg = await reader.read()
except asyncio.CancelledError:
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
raise
except Exception as exc:
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
self._exception = exc
return True
if self._closing:
return True

if msg.type == WSMsgType.CLOSE:
self._close_code = msg.data
return True
reader = self._reader
assert reader is not None
try:
async with async_timeout.timeout(self._timeout):
msg = await reader.read()
except asyncio.CancelledError:
self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE)
raise
except Exception as exc:
self._exception = exc
self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE)
return True

self._close_code = WSCloseCode.ABNORMAL_CLOSURE
self._exception = asyncio.TimeoutError()
if msg.type == WSMsgType.CLOSE:
self._set_code_close_transport(msg.data)
return True
else:
return False

self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE)
self._exception = asyncio.TimeoutError()
return True

def _set_code_close_transport(self, code: WSCloseCode) -> None:
"""Set the close code and close the transport."""
self._close_code = code
if self._req is not None and self._req.transport is not None:
self._req.transport.close()

async def receive(self, timeout: Optional[float] = None) -> WSMessage:
if self._reader is None:
Expand Down Expand Up @@ -466,7 +475,7 @@ async def receive(self, timeout: Optional[float] = None) -> WSMessage:
set_result(waiter, True)
self._waiting = None
except (asyncio.CancelledError, asyncio.TimeoutError):
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE)
raise
except EofStream:
self._close_code = WSCloseCode.OK
Expand All @@ -488,7 +497,11 @@ async def receive(self, timeout: Optional[float] = None) -> WSMessage:
self._close_code = msg.data
# Could be closed while awaiting reader.
if not self._closed and self._autoclose: # type: ignore[redundant-expr]
await self.close()
# The client is likely going to close the
# connection out from under us so we do not
# want to drain any pending writes as it will
# likely result writing to a broken pipe.
await self.close(drain=False)
elif msg.type == WSMsgType.CLOSING:
self._closing = True
elif msg.type == WSMsgType.PING and self._autoping:
Expand Down
12 changes: 11 additions & 1 deletion docs/web_reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -988,6 +988,14 @@ and :ref:`aiohttp-web-signals` handlers::

.. versionadded:: 3.3

:param bool autoclose: Close connection when the client sends
a :const:`~aiohttp.WSMsgType.CLOSE` message,
``True`` by default. If set to ``False``,
the connection is not closed and the
caller is responsible for calling
``request.transport.close()`` to avoid
leaking resources.


The class supports ``async for`` statement for iterating over
incoming messages::
Expand Down Expand Up @@ -1164,7 +1172,7 @@ and :ref:`aiohttp-web-signals` handlers::
The method is converted into :term:`coroutine`,
*compress* parameter added.

.. method:: close(*, code=WSCloseCode.OK, message=b'')
.. method:: close(*, code=WSCloseCode.OK, message=b'', drain=True)
:async:

A :ref:`coroutine<coroutine>` that initiates closing
Expand All @@ -1178,6 +1186,8 @@ and :ref:`aiohttp-web-signals` handlers::
:class:`str` (converted to *UTF-8* encoded bytes)
or :class:`bytes`.

:param bool drain: drain outgoing buffer before closing connection.

:raise RuntimeError: if connection is not started

.. method:: receive(timeout=None)
Expand Down
26 changes: 26 additions & 0 deletions tests/test_web_websocket.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# type: ignore
import asyncio
import time
from typing import Any
from unittest import mock

Expand Down Expand Up @@ -139,6 +140,20 @@ async def test_write_non_prepared() -> None:
await ws.write(b"data")


async def test_heartbeat_timeout(make_request: Any) -> None:
"""Verify the transport is closed when the heartbeat timeout is reached."""
loop = asyncio.get_running_loop()
future = loop.create_future()
req = make_request("GET", "/")
lowest_time = time.get_clock_info("monotonic").resolution
req._protocol._timeout_ceil_threshold = lowest_time
ws = WebSocketResponse(heartbeat=lowest_time, timeout=lowest_time)
await ws.prepare(req)
ws._req.transport.close.side_effect = lambda: future.set_result(None)
await future
assert ws.closed


def test_websocket_ready() -> None:
websocket_ready = WebSocketReady(True, "chat")
assert websocket_ready.ok is True
Expand Down Expand Up @@ -207,6 +222,7 @@ async def test_send_str_closed(make_request: Any) -> None:
await ws.prepare(req)
ws._reader.feed_data(WS_CLOSED_MESSAGE, 0)
await ws.close()
assert len(ws._req.transport.close.mock_calls) == 1

with pytest.raises(ConnectionError):
await ws.send_str("string")
Expand Down Expand Up @@ -263,6 +279,8 @@ async def test_close_idempotent(make_request: Any) -> None:
ws._reader.feed_data(WS_CLOSED_MESSAGE, 0)
assert await ws.close(code=1, message="message1")
assert ws.closed
assert len(ws._req.transport.close.mock_calls) == 1

assert not (await ws.close(code=2, message="message2"))


Expand Down Expand Up @@ -296,12 +314,15 @@ async def test_write_eof_idempotent(make_request: Any) -> None:
req = make_request("GET", "/")
ws = WebSocketResponse()
await ws.prepare(req)
assert len(ws._req.transport.close.mock_calls) == 0

ws._reader.feed_data(WS_CLOSED_MESSAGE, 0)
await ws.close()

await ws.write_eof()
await ws.write_eof()
await ws.write_eof()
assert len(ws._req.transport.close.mock_calls) == 1


async def test_receive_eofstream_in_reader(make_request: Any, loop: Any) -> None:
Expand All @@ -327,6 +348,7 @@ async def test_receive_timeouterror(make_request: Any, loop: Any) -> None:
req = make_request("GET", "/")
ws = WebSocketResponse()
await ws.prepare(req)
assert len(ws._req.transport.close.mock_calls) == 0

ws._reader = mock.Mock()
res = loop.create_future()
Expand All @@ -336,6 +358,8 @@ async def test_receive_timeouterror(make_request: Any, loop: Any) -> None:
with pytest.raises(asyncio.TimeoutError):
await ws.receive()

assert len(ws._req.transport.close.mock_calls) == 1


async def test_multiple_receive_on_close_connection(make_request: Any) -> None:
req = make_request("GET", "/")
Expand Down Expand Up @@ -367,13 +391,15 @@ async def test_close_exc(make_request: Any) -> None:
req = make_request("GET", "/")
ws = WebSocketResponse()
await ws.prepare(req)
assert len(ws._req.transport.close.mock_calls) == 0

exc = ValueError()
ws._writer = mock.Mock()
ws._writer.close.side_effect = exc
await ws.close()
assert ws.closed
assert ws.exception() is exc
assert len(ws._req.transport.close.mock_calls) == 1

ws._closed = False
ws._writer.close.side_effect = asyncio.CancelledError()
Expand Down

0 comments on commit 6f1c608

Please sign in to comment.