Skip to content

Commit

Permalink
Fix websocket connection leak (#7978) (#7980)
Browse files Browse the repository at this point in the history
  • Loading branch information
bdraco committed Dec 19, 2023
1 parent f1cee99 commit 477b237
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 41 deletions.
1 change: 1 addition & 0 deletions CHANGES/7978.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix websocket connection leak
94 changes: 54 additions & 40 deletions aiohttp/web_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,9 +140,8 @@ def _send_heartbeat(self) -> None:
def _pong_not_received(self) -> None:
if self._req is not None and self._req.transport is not None:
self._closed = True
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE)
self._exception = asyncio.TimeoutError()
self._req.transport.close()

async def prepare(self, request: BaseRequest) -> AbstractStreamWriter:
# make pre-check to don't hide it by do_handshake() exceptions
Expand Down Expand Up @@ -360,7 +359,10 @@ async def write_eof(self) -> None: # type: ignore[override]
await self.close()
self._eof_sent = True

async def close(self, *, code: int = WSCloseCode.OK, message: bytes = b"") -> bool:
async def close(
self, *, code: int = WSCloseCode.OK, message: bytes = b"", drain: bool = True
) -> bool:
"""Close websocket connection."""
if self._writer is None:
raise RuntimeError("Call .prepare() first")

Expand All @@ -374,46 +376,53 @@ async def close(self, *, code: int = WSCloseCode.OK, message: bytes = b"") -> bo
reader.feed_data(WS_CLOSING_MESSAGE, 0)
await self._waiting

if not self._closed:
self._closed = True
try:
await self._writer.close(code, message)
writer = self._payload_writer
assert writer is not None
await writer.drain()
except (asyncio.CancelledError, asyncio.TimeoutError):
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
raise
except Exception as exc:
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
self._exception = exc
return True
if self._closed:
return False

if self._closing:
return True
self._closed = True
try:
await self._writer.close(code, message)
writer = self._payload_writer
assert writer is not None
if drain:
await writer.drain()
except (asyncio.CancelledError, asyncio.TimeoutError):
self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE)
raise
except Exception as exc:
self._exception = exc
self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE)
return True

reader = self._reader
assert reader is not None
try:
async with async_timeout.timeout(self._timeout):
msg = await reader.read()
except asyncio.CancelledError:
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
raise
except Exception as exc:
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
self._exception = exc
return True
if self._closing:
return True

if msg.type == WSMsgType.CLOSE:
self._close_code = msg.data
return True
reader = self._reader
assert reader is not None
try:
async with async_timeout.timeout(self._timeout):
msg = await reader.read()
except asyncio.CancelledError:
self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE)
raise
except Exception as exc:
self._exception = exc
self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE)
return True

self._close_code = WSCloseCode.ABNORMAL_CLOSURE
self._exception = asyncio.TimeoutError()
if msg.type == WSMsgType.CLOSE:
self._set_code_close_transport(msg.data)
return True
else:
return False

self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE)
self._exception = asyncio.TimeoutError()
return True

def _set_code_close_transport(self, code: WSCloseCode) -> None:
"""Set the close code and close the transport."""
self._close_code = code
if self._req is not None and self._req.transport is not None:
self._req.transport.close()

async def receive(self, timeout: Optional[float] = None) -> WSMessage:
if self._reader is None:
Expand Down Expand Up @@ -444,7 +453,7 @@ async def receive(self, timeout: Optional[float] = None) -> WSMessage:
set_result(waiter, True)
self._waiting = None
except (asyncio.CancelledError, asyncio.TimeoutError):
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE)
raise
except EofStream:
self._close_code = WSCloseCode.OK
Expand All @@ -464,8 +473,13 @@ async def receive(self, timeout: Optional[float] = None) -> WSMessage:
if msg.type == WSMsgType.CLOSE:
self._closing = True
self._close_code = msg.data
# Could be closed while awaiting reader.
if not self._closed and self._autoclose:
await self.close()
# The client is likely going to close the
# connection out from under us so we do not
# want to drain any pending writes as it will
# likely result writing to a broken pipe.
await self.close(drain=False)
elif msg.type == WSMsgType.CLOSING:
self._closing = True
elif msg.type == WSMsgType.PING and self._autoping:
Expand Down
12 changes: 11 additions & 1 deletion docs/web_reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -970,6 +970,14 @@ and :ref:`aiohttp-web-signals` handlers::

.. versionadded:: 3.3

:param bool autoclose: Close connection when the client sends
a :const:`~aiohttp.WSMsgType.CLOSE` message,
``True`` by default. If set to ``False``,
the connection is not closed and the
caller is responsible for calling
``request.transport.close()`` to avoid
leaking resources.


The class supports ``async for`` statement for iterating over
incoming messages::
Expand Down Expand Up @@ -1146,7 +1154,7 @@ and :ref:`aiohttp-web-signals` handlers::
The method is converted into :term:`coroutine`,
*compress* parameter added.

.. method:: close(*, code=WSCloseCode.OK, message=b'')
.. method:: close(*, code=WSCloseCode.OK, message=b'', drain=True)
:async:

A :ref:`coroutine<coroutine>` that initiates closing
Expand All @@ -1160,6 +1168,8 @@ and :ref:`aiohttp-web-signals` handlers::
:class:`str` (converted to *UTF-8* encoded bytes)
or :class:`bytes`.

:param bool drain: drain outgoing buffer before closing connection.

:raise RuntimeError: if connection is not started

.. method:: receive(timeout=None)
Expand Down
27 changes: 27 additions & 0 deletions tests/test_web_websocket.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import asyncio
import time
from typing import Any
from unittest import mock

import aiosignal
Expand Down Expand Up @@ -165,6 +167,20 @@ async def test_write_non_prepared() -> None:
await ws.write(b"data")


async def test_heartbeat_timeout(make_request: Any) -> None:
"""Verify the transport is closed when the heartbeat timeout is reached."""
loop = asyncio.get_running_loop()
future = loop.create_future()
req = make_request("GET", "/")
lowest_time = time.get_clock_info("monotonic").resolution
req._protocol._timeout_ceil_threshold = lowest_time
ws = WebSocketResponse(heartbeat=lowest_time, timeout=lowest_time)
await ws.prepare(req)
ws._req.transport.close.side_effect = lambda: future.set_result(None)
await future
assert ws.closed


def test_websocket_ready() -> None:
websocket_ready = WebSocketReady(True, "chat")
assert websocket_ready.ok is True
Expand Down Expand Up @@ -233,6 +249,7 @@ async def test_send_str_closed(make_request) -> None:
await ws.prepare(req)
ws._reader.feed_data(WS_CLOSED_MESSAGE, 0)
await ws.close()
assert len(ws._req.transport.close.mock_calls) == 1

with pytest.raises(ConnectionError):
await ws.send_str("string")
Expand Down Expand Up @@ -289,6 +306,8 @@ async def test_close_idempotent(make_request) -> None:
ws._reader.feed_data(WS_CLOSED_MESSAGE, 0)
assert await ws.close(code=1, message="message1")
assert ws.closed
assert len(ws._req.transport.close.mock_calls) == 1

assert not (await ws.close(code=2, message="message2"))


Expand Down Expand Up @@ -322,12 +341,15 @@ async def test_write_eof_idempotent(make_request) -> None:
req = make_request("GET", "/")
ws = WebSocketResponse()
await ws.prepare(req)
assert len(ws._req.transport.close.mock_calls) == 0

ws._reader.feed_data(WS_CLOSED_MESSAGE, 0)
await ws.close()

await ws.write_eof()
await ws.write_eof()
await ws.write_eof()
assert len(ws._req.transport.close.mock_calls) == 1


async def test_receive_eofstream_in_reader(make_request, loop) -> None:
Expand All @@ -353,6 +375,7 @@ async def test_receive_timeouterror(make_request, loop) -> None:
req = make_request("GET", "/")
ws = WebSocketResponse()
await ws.prepare(req)
assert len(ws._req.transport.close.mock_calls) == 0

ws._reader = mock.Mock()
res = loop.create_future()
Expand All @@ -362,6 +385,8 @@ async def test_receive_timeouterror(make_request, loop) -> None:
with pytest.raises(asyncio.TimeoutError):
await ws.receive()

assert len(ws._req.transport.close.mock_calls) == 1


async def test_multiple_receive_on_close_connection(make_request) -> None:
req = make_request("GET", "/")
Expand Down Expand Up @@ -394,13 +419,15 @@ async def test_close_exc(make_request) -> None:
req = make_request("GET", "/")
ws = WebSocketResponse()
await ws.prepare(req)
assert len(ws._req.transport.close.mock_calls) == 0

exc = ValueError()
ws._writer = mock.Mock()
ws._writer.close.side_effect = exc
await ws.close()
assert ws.closed
assert ws.exception() is exc
assert len(ws._req.transport.close.mock_calls) == 1

ws._closed = False
ws._writer.close.side_effect = asyncio.CancelledError()
Expand Down

0 comments on commit 477b237

Please sign in to comment.