Skip to content

Commit

Permalink
Deliver a websocket.disconnect message to the app even if it closes/r…
Browse files Browse the repository at this point in the history
…ejects itself.
  • Loading branch information
kristjanvalur committed Mar 22, 2023
1 parent 39ba68c commit 8810e35
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 2 deletions.
15 changes: 15 additions & 0 deletions starlette/websockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def __init__(self, scope: Scope, receive: Receive, send: Send) -> None:
self._send = send
self.client_state = WebSocketState.CONNECTING
self.application_state = WebSocketState.CONNECTING
self.app_disconnect_msg: typing.Optional[Message] = None

def _have_response_extension(self) -> bool:
return "websocket.http.response" in self.scope.get("extensions", {})
Expand All @@ -36,6 +37,11 @@ async def receive(self) -> Message:
"""
Receive ASGI websocket messages, ensuring valid state transitions.
"""
if self.app_disconnect_msg is not None:
# return message which resulted from app disconnect
msg = self.app_disconnect_msg
self.app_disconnect_msg = None
return msg
if self.client_state == WebSocketState.CONNECTING:
message = await self._receive()
message_type = message["type"]
Expand All @@ -56,6 +62,8 @@ async def receive(self) -> Message:
)
if message_type == "websocket.disconnect":
self.client_state = WebSocketState.DISCONNECTED
if "code" not in message:
message["code"] = 1005 # websocket spec
return message
else:
raise RuntimeError(
Expand All @@ -80,6 +88,8 @@ async def send(self, message: Message) -> None:
)
if message_type == "websocket.close":
self.application_state = WebSocketState.DISCONNECTED
# no close frame is sent, then the default is 1006
self.app_disconnect_msg = {"type": "websocket.disconnect", "code": 1006}
elif message_type == "websocket.http.response.start":
self.application_state = WebSocketState.RESPONSE
else:
Expand All @@ -94,6 +104,10 @@ async def send(self, message: Message) -> None:
)
if message_type == "websocket.close":
self.application_state = WebSocketState.DISCONNECTED
self.app_disconnect_msg = {
"type": "websocket.disconnect",
"code": message.get("code", 1000),
}
await self._send(message)
elif self.application_state == WebSocketState.RESPONSE:
message_type = message["type"]
Expand All @@ -104,6 +118,7 @@ async def send(self, message: Message) -> None:
)
if not message.get("more_body", False):
self.application_state = WebSocketState.DISCONNECTED
self.app_disconnect_msg = {"type": "websocket.disconnect", "code": 1006}
await self._send(message)
else:
raise RuntimeError('Cannot call "send" once a close message has been sent.')
Expand Down
36 changes: 34 additions & 2 deletions tests/test_websockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from starlette import status
from starlette.responses import Response
from starlette.testclient import WebSocketReject
from starlette.types import Receive, Scope, Send
from starlette.types import Message, Receive, Scope, Send
from starlette.websockets import WebSocket, WebSocketDisconnect, WebSocketState


Expand Down Expand Up @@ -226,63 +226,95 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:


def test_application_close(test_client_factory):
close_msg: Message = {}

async def app(scope: Scope, receive: Receive, send: Send) -> None:
nonlocal close_msg
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.accept()
await websocket.close(status.WS_1001_GOING_AWAY)
close_msg = await websocket.receive()

client = test_client_factory(app)
with client.websocket_connect("/") as websocket:
with pytest.raises(WebSocketDisconnect) as exc:
websocket.receive_text()
assert exc.value.code == status.WS_1001_GOING_AWAY
assert exc.value.code == status.WS_1001_GOING_AWAY
assert close_msg == {
"type": "websocket.disconnect",
"code": status.WS_1001_GOING_AWAY,
}


def test_rejected_connection(test_client_factory):
close_msg: Message = {}

async def app(scope: Scope, receive: Receive, send: Send) -> None:
nonlocal close_msg
websocket = WebSocket(scope, receive=receive, send=send)
msg = await websocket.receive()
assert msg == {"type": "websocket.connect"}
await websocket.close(status.WS_1001_GOING_AWAY)
close_msg = await websocket.receive()

client = test_client_factory(app)
with pytest.raises(WebSocketReject) as exc:
with client.websocket_connect("/"):
pass # pragma: nocover
assert exc.value.code == status.WS_1001_GOING_AWAY
assert exc.value.response_status == 403
assert close_msg == {
"type": "websocket.disconnect",
"code": status.WS_1006_ABNORMAL_CLOSURE,
}


def test_send_response(test_client_factory):
close_msg: Message = {}

async def app(scope: Scope, receive: Receive, send: Send) -> None:
nonlocal close_msg
websocket = WebSocket(scope, receive=receive, send=send)
msg = await websocket.receive()
assert msg == {"type": "websocket.connect"}
response = Response(status_code=404, content="foo")
await websocket.send_response(response)
close_msg = await websocket.receive()

client = test_client_factory(app)
with pytest.raises(WebSocketReject) as exc:
with client.websocket_connect("/"):
pass # pragma: nocover
assert exc.value.response_status == 404
assert exc.value.response_body == b"foo"
assert close_msg == {
"type": "websocket.disconnect",
"code": status.WS_1006_ABNORMAL_CLOSURE,
}


def test_send_response_unsupported(test_client_factory):
close_msg: Message = {}

async def app(scope: Scope, receive: Receive, send: Send) -> None:
nonlocal close_msg
del scope["extensions"]["websocket.http.response"]
websocket = WebSocket(scope, receive=receive, send=send)
msg = await websocket.receive()
assert msg == {"type": "websocket.connect"}
response = Response(status_code=404, content="foo")
await websocket.send_response(response)
close_msg = await websocket.receive()

client = test_client_factory(app)
with pytest.raises(WebSocketDisconnect) as exc:
with client.websocket_connect("/"):
pass # pragma: nocover
assert exc.value.code == status.WS_1008_POLICY_VIOLATION
assert close_msg == {
"type": "websocket.disconnect",
"code": status.WS_1006_ABNORMAL_CLOSURE,
}


def test_send_response_invalid(test_client_factory):
Expand Down

0 comments on commit 8810e35

Please sign in to comment.