Skip to content

Commit

Permalink
websocket: Add warning if client connection isn't closed cleanly
Browse files Browse the repository at this point in the history
This gives a warning that is not dependent on GC for the issue
in tornadoweb#3257. This new warning covers all websocket client connections,
while the previous GC-dependent warning only affected those with
ping_interval set. This unfortunately introduces an effective
requirement to close all websocket clients explicitly for those
who are strict about warnings.
  • Loading branch information
bdarnell committed May 7, 2023
1 parent e0fa53e commit 413281e
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 23 deletions.
57 changes: 34 additions & 23 deletions tornado/test/websocket_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import contextlib
import functools
import socket
import traceback
Expand Down Expand Up @@ -213,11 +214,21 @@ def open(self):


class WebSocketBaseTestCase(AsyncHTTPTestCase):
def setUp(self):
super().setUp()
self.conns_to_close = []

def tearDown(self):
for conn in self.conns_to_close:
conn.close()
super().tearDown()

@gen.coroutine
def ws_connect(self, path, **kwargs):
ws = yield websocket_connect(
"ws://127.0.0.1:%d%s" % (self.get_http_port(), path), **kwargs
)
self.conns_to_close.append(ws)
raise gen.Return(ws)


Expand Down Expand Up @@ -397,39 +408,39 @@ def test_websocket_network_fail(self):

@gen_test
def test_websocket_close_buffered_data(self):
ws = yield websocket_connect("ws://127.0.0.1:%d/echo" % self.get_http_port())
ws.write_message("hello")
ws.write_message("world")
# Close the underlying stream.
ws.stream.close()
with contextlib.closing((yield websocket_connect("ws://127.0.0.1:%d/echo" % self.get_http_port()))) as ws:
ws.write_message("hello")
ws.write_message("world")
# Close the underlying stream.
ws.stream.close()

@gen_test
def test_websocket_headers(self):
# Ensure that arbitrary headers can be passed through websocket_connect.
ws = yield websocket_connect(
with contextlib.closing((yield websocket_connect(
HTTPRequest(
"ws://127.0.0.1:%d/header" % self.get_http_port(),
headers={"X-Test": "hello"},
)
)
response = yield ws.read_message()
self.assertEqual(response, "hello")
))) as ws:
response = yield ws.read_message()
self.assertEqual(response, "hello")

@gen_test
def test_websocket_header_echo(self):
# Ensure that headers can be returned in the response.
# Specifically, that arbitrary headers passed through websocket_connect
# can be returned.
ws = yield websocket_connect(
with contextlib.closing((yield websocket_connect(
HTTPRequest(
"ws://127.0.0.1:%d/header_echo" % self.get_http_port(),
headers={"X-Test-Hello": "hello"},
)
)
self.assertEqual(ws.headers.get("X-Test-Hello"), "hello")
self.assertEqual(
ws.headers.get("X-Extra-Response-Header"), "Extra-Response-Value"
)
))) as ws:
self.assertEqual(ws.headers.get("X-Test-Hello"), "hello")
self.assertEqual(
ws.headers.get("X-Extra-Response-Header"), "Extra-Response-Value"
)

@gen_test
def test_server_close_reason(self):
Expand Down Expand Up @@ -495,10 +506,10 @@ def test_check_origin_valid_no_path(self):
url = "ws://127.0.0.1:%d/echo" % port
headers = {"Origin": "http://127.0.0.1:%d" % port}

ws = yield websocket_connect(HTTPRequest(url, headers=headers))
ws.write_message("hello")
response = yield ws.read_message()
self.assertEqual(response, "hello")
with contextlib.closing((yield websocket_connect(HTTPRequest(url, headers=headers)))) as ws:
ws.write_message("hello")
response = yield ws.read_message()
self.assertEqual(response, "hello")

@gen_test
def test_check_origin_valid_with_path(self):
Expand All @@ -507,10 +518,10 @@ def test_check_origin_valid_with_path(self):
url = "ws://127.0.0.1:%d/echo" % port
headers = {"Origin": "http://127.0.0.1:%d/something" % port}

ws = yield websocket_connect(HTTPRequest(url, headers=headers))
ws.write_message("hello")
response = yield ws.read_message()
self.assertEqual(response, "hello")
with contextlib.closing((yield websocket_connect(HTTPRequest(url, headers=headers)))) as ws:
ws.write_message("hello")
response = yield ws.read_message()
self.assertEqual(response, "hello")

@gen_test
def test_check_origin_invalid_partial_url(self):
Expand Down
10 changes: 10 additions & 0 deletions tornado/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import struct
import tornado
from urllib.parse import urlparse
import warnings
import zlib

from tornado.concurrent import Future, future_set_result_unless_cancelled
Expand Down Expand Up @@ -1410,6 +1411,15 @@ def __init__(
104857600,
)

def __del__(self) -> None:
if self.protocol is not None:
# Unclosed client connections can sometimes log "task was destroyed but
# was pending" warnings if shutdown strikes at the wrong time (such as
# while a ping is being processed due to ping_interval). Log our own
# warning to make it a little more deterministic (although it's still
# dependent on GC timing).
warnings.warn("Unclosed WebSocketClientConnection", ResourceWarning)

def close(self, code: Optional[int] = None, reason: Optional[str] = None) -> None:
"""Closes the websocket connection.
Expand Down

0 comments on commit 413281e

Please sign in to comment.