Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨ Add exception handler for WebSocketRequestValidationError (which also allows to override it) #6030

Merged
merged 10 commits into from
Jun 11, 2023
8 changes: 7 additions & 1 deletion fastapi/applications.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@
from fastapi.exception_handlers import (
http_exception_handler,
request_validation_exception_handler,
websocket_request_validation_exception_handler,
)
from fastapi.exceptions import RequestValidationError
from fastapi.exceptions import RequestValidationError, WebSocketRequestValidationError
from fastapi.logger import logger
from fastapi.middleware.asyncexitstack import AsyncExitStackMiddleware
from fastapi.openapi.docs import (
Expand Down Expand Up @@ -145,6 +146,11 @@ def __init__(
self.exception_handlers.setdefault(
RequestValidationError, request_validation_exception_handler
)
self.exception_handlers.setdefault(
WebSocketRequestValidationError,
# Starlette still has incorrect type specification for the handlers
websocket_request_validation_exception_handler, # type: ignore
)

self.user_middleware: List[Middleware] = (
[] if middleware is None else list(middleware)
Expand Down
13 changes: 11 additions & 2 deletions fastapi/exception_handlers.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import RequestValidationError
from fastapi.exceptions import RequestValidationError, WebSocketRequestValidationError
from fastapi.utils import is_body_allowed_for_status_code
from fastapi.websockets import WebSocket
from starlette.exceptions import HTTPException
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY
from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY, WS_1008_POLICY_VIOLATION


async def http_exception_handler(request: Request, exc: HTTPException) -> Response:
Expand All @@ -23,3 +24,11 @@ async def request_validation_exception_handler(
status_code=HTTP_422_UNPROCESSABLE_ENTITY,
content={"detail": jsonable_encoder(exc.errors())},
)


async def websocket_request_validation_exception_handler(
websocket: WebSocket, exc: WebSocketRequestValidationError
) -> None:
await websocket.close(
code=WS_1008_POLICY_VIOLATION, reason=jsonable_encoder(exc.errors())
)
2 changes: 0 additions & 2 deletions fastapi/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@
request_response,
websocket_session,
)
from starlette.status import WS_1008_POLICY_VIOLATION
from starlette.types import ASGIApp, Lifespan, Scope
from starlette.websockets import WebSocket

Expand Down Expand Up @@ -283,7 +282,6 @@ async def app(websocket: WebSocket) -> None:
)
values, errors, _, _2, _3 = solved_result
if errors:
await websocket.close(code=WS_1008_POLICY_VIOLATION)
raise WebSocketRequestValidationError(errors)
assert dependant.call is not None, "dependant.call must be a function"
await dependant.call(**values)
Expand Down
152 changes: 148 additions & 4 deletions tests/test_ws_router.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,16 @@
from fastapi import APIRouter, Depends, FastAPI, WebSocket
import functools

import pytest
from fastapi import (
APIRouter,
Depends,
FastAPI,
Header,
WebSocket,
WebSocketDisconnect,
status,
)
from fastapi.middleware import Middleware
from fastapi.testclient import TestClient

router = APIRouter()
Expand Down Expand Up @@ -63,9 +75,44 @@ async def router_native_prefix_ws(websocket: WebSocket):
await websocket.close()


app.include_router(router)
app.include_router(prefix_router, prefix="/prefix")
app.include_router(native_prefix_route)
async def ws_dependency_err():
raise NotImplementedError()


@router.websocket("/depends-err/")
async def router_ws_depends_err(websocket: WebSocket, data=Depends(ws_dependency_err)):
pass # pragma: no cover


async def ws_dependency_validate(x_missing: str = Header()):
pass # pragma: no cover


@router.websocket("/depends-validate/")
async def router_ws_depends_validate(
websocket: WebSocket, data=Depends(ws_dependency_validate)
):
pass # pragma: no cover


class CustomError(Exception):
pass


@router.websocket("/custom_error/")
async def router_ws_custom_error(websocket: WebSocket):
raise CustomError()


def make_app(app=None, **kwargs):
app = app or FastAPI(**kwargs)
app.include_router(router)
app.include_router(prefix_router, prefix="/prefix")
app.include_router(native_prefix_route)
return app


app = make_app(app)


def test_app():
Expand Down Expand Up @@ -125,3 +172,100 @@ def test_router_with_params():
assert data == "path/to/file"
data = websocket.receive_text()
assert data == "a_query_param"


def test_wrong_uri():
"""
Verify that a websocket connection to a non-existent endpoing returns in a shutdown
"""
client = TestClient(app)
with pytest.raises(WebSocketDisconnect) as e:
with client.websocket_connect("/no-router/"):
pass # pragma: no cover
assert e.value.code == status.WS_1000_NORMAL_CLOSURE


def websocket_middleware(middleware_func):
"""
Helper to create a Starlette pure websocket middleware
"""

def middleware_constructor(app):
@functools.wraps(app)
async def wrapped_app(scope, receive, send):
if scope["type"] != "websocket":
return await app(scope, receive, send) # pragma: no cover

async def call_next():
return await app(scope, receive, send)

websocket = WebSocket(scope, receive=receive, send=send)
return await middleware_func(websocket, call_next)

return wrapped_app

return middleware_constructor


def test_depend_validation():
"""
Verify that a validation in a dependency invokes the correct exception handler
"""
caught = []

@websocket_middleware
async def catcher(websocket, call_next):
try:
return await call_next()
except Exception as e: # pragma: no cover
caught.append(e)
raise

myapp = make_app(middleware=[Middleware(catcher)])

client = TestClient(myapp)
with pytest.raises(WebSocketDisconnect) as e:
with client.websocket_connect("/depends-validate/"):
pass # pragma: no cover
# the validation error does produce a close message
assert e.value.code == status.WS_1008_POLICY_VIOLATION
# and no error is leaked
assert caught == []


def test_depend_err_middleware():
"""
Verify that it is possible to write custom WebSocket middleware to catch errors
"""

@websocket_middleware
async def errorhandler(websocket: WebSocket, call_next):
try:
return await call_next()
except Exception as e:
await websocket.close(code=status.WS_1006_ABNORMAL_CLOSURE, reason=repr(e))

myapp = make_app(middleware=[Middleware(errorhandler)])
client = TestClient(myapp)
with pytest.raises(WebSocketDisconnect) as e:
with client.websocket_connect("/depends-err/"):
pass # pragma: no cover
assert e.value.code == status.WS_1006_ABNORMAL_CLOSURE
assert "NotImplementedError" in e.value.reason


def test_depend_err_handler():
"""
Verify that it is possible to write custom WebSocket middleware to catch errors
"""

async def custom_handler(websocket: WebSocket, exc: CustomError) -> None:
await websocket.close(1002, "foo")

myapp = make_app(exception_handlers={CustomError: custom_handler})
client = TestClient(myapp)
with pytest.raises(WebSocketDisconnect) as e:
with client.websocket_connect("/custom_error/"):
pass # pragma: no cover
assert e.value.code == 1002
assert "foo" in e.value.reason