Skip to content

Commit

Permalink
✨ Add support for dependencies in WebSocket routes (#4534)
Browse files Browse the repository at this point in the history
Co-authored-by: Sebastián Ramírez <tiangolo@gmail.com>
  • Loading branch information
paulo-raca and tiangolo committed Jun 11, 2023
1 parent ee96a09 commit d8b8f21
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 13 deletions.
27 changes: 23 additions & 4 deletions fastapi/applications.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,15 +401,34 @@ def decorator(func: DecoratedCallable) -> DecoratedCallable:
return decorator

def add_api_websocket_route(
self, path: str, endpoint: Callable[..., Any], name: Optional[str] = None
self,
path: str,
endpoint: Callable[..., Any],
name: Optional[str] = None,
*,
dependencies: Optional[Sequence[Depends]] = None,
) -> None:
self.router.add_api_websocket_route(path, endpoint, name=name)
self.router.add_api_websocket_route(
path,
endpoint,
name=name,
dependencies=dependencies,
)

def websocket(
self, path: str, name: Optional[str] = None
self,
path: str,
name: Optional[str] = None,
*,
dependencies: Optional[Sequence[Depends]] = None,
) -> Callable[[DecoratedCallable], DecoratedCallable]:
def decorator(func: DecoratedCallable) -> DecoratedCallable:
self.add_api_websocket_route(path, func, name=name)
self.add_api_websocket_route(
path,
func,
name=name,
dependencies=dependencies,
)
return func

return decorator
Expand Down
47 changes: 38 additions & 9 deletions fastapi/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,13 +296,21 @@ def __init__(
endpoint: Callable[..., Any],
*,
name: Optional[str] = None,
dependencies: Optional[Sequence[params.Depends]] = None,
dependency_overrides_provider: Optional[Any] = None,
) -> None:
self.path = path
self.endpoint = endpoint
self.name = get_name(endpoint) if name is None else name
self.dependencies = list(dependencies or [])
self.path_regex, self.path_format, self.param_convertors = compile_path(path)
self.dependant = get_dependant(path=self.path_format, call=self.endpoint)
for depends in self.dependencies[::-1]:
self.dependant.dependencies.insert(
0,
get_parameterless_sub_dependant(depends=depends, path=self.path_format),
)

self.app = websocket_session(
get_websocket_app(
dependant=self.dependant,
Expand Down Expand Up @@ -416,10 +424,7 @@ def __init__(
else:
self.response_field = None # type: ignore
self.secure_cloned_response_field = None
if dependencies:
self.dependencies = list(dependencies)
else:
self.dependencies = []
self.dependencies = list(dependencies or [])
self.description = description or inspect.cleandoc(self.endpoint.__doc__ or "")
# if a "form feed" character (page break) is found in the description text,
# truncate description text to the content preceding the first "form feed"
Expand Down Expand Up @@ -514,7 +519,7 @@ def __init__(
), "A path prefix must not end with '/', as the routes will start with '/'"
self.prefix = prefix
self.tags: List[Union[str, Enum]] = tags or []
self.dependencies = list(dependencies or []) or []
self.dependencies = list(dependencies or [])
self.deprecated = deprecated
self.include_in_schema = include_in_schema
self.responses = responses or {}
Expand Down Expand Up @@ -688,21 +693,37 @@ def decorator(func: DecoratedCallable) -> DecoratedCallable:
return decorator

def add_api_websocket_route(
self, path: str, endpoint: Callable[..., Any], name: Optional[str] = None
self,
path: str,
endpoint: Callable[..., Any],
name: Optional[str] = None,
*,
dependencies: Optional[Sequence[params.Depends]] = None,
) -> None:
current_dependencies = self.dependencies.copy()
if dependencies:
current_dependencies.extend(dependencies)

route = APIWebSocketRoute(
self.prefix + path,
endpoint=endpoint,
name=name,
dependencies=current_dependencies,
dependency_overrides_provider=self.dependency_overrides_provider,
)
self.routes.append(route)

def websocket(
self, path: str, name: Optional[str] = None
self,
path: str,
name: Optional[str] = None,
*,
dependencies: Optional[Sequence[params.Depends]] = None,
) -> Callable[[DecoratedCallable], DecoratedCallable]:
def decorator(func: DecoratedCallable) -> DecoratedCallable:
self.add_api_websocket_route(path, func, name=name)
self.add_api_websocket_route(
path, func, name=name, dependencies=dependencies
)
return func

return decorator
Expand Down Expand Up @@ -817,8 +838,16 @@ def include_router(
name=route.name,
)
elif isinstance(route, APIWebSocketRoute):
current_dependencies = []
if dependencies:
current_dependencies.extend(dependencies)
if route.dependencies:
current_dependencies.extend(route.dependencies)
self.add_api_websocket_route(
prefix + route.path, route.endpoint, name=route.name
prefix + route.path,
route.endpoint,
dependencies=current_dependencies,
name=route.name,
)
elif isinstance(route, routing.WebSocketRoute):
self.add_websocket_route(
Expand Down
73 changes: 73 additions & 0 deletions tests/test_ws_dependencies.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import json
from typing import List

from fastapi import APIRouter, Depends, FastAPI, WebSocket
from fastapi.testclient import TestClient
from typing_extensions import Annotated


def dependency_list() -> List[str]:
return []


DepList = Annotated[List[str], Depends(dependency_list)]


def create_dependency(name: str):
def fun(deps: DepList):
deps.append(name)

return Depends(fun)


router = APIRouter(dependencies=[create_dependency("router")])
prefix_router = APIRouter(dependencies=[create_dependency("prefix_router")])
app = FastAPI(dependencies=[create_dependency("app")])


@app.websocket("/", dependencies=[create_dependency("index")])
async def index(websocket: WebSocket, deps: DepList):
await websocket.accept()
await websocket.send_text(json.dumps(deps))
await websocket.close()


@router.websocket("/router", dependencies=[create_dependency("routerindex")])
async def routerindex(websocket: WebSocket, deps: DepList):
await websocket.accept()
await websocket.send_text(json.dumps(deps))
await websocket.close()


@prefix_router.websocket("/", dependencies=[create_dependency("routerprefixindex")])
async def routerprefixindex(websocket: WebSocket, deps: DepList):
await websocket.accept()
await websocket.send_text(json.dumps(deps))
await websocket.close()


app.include_router(router, dependencies=[create_dependency("router2")])
app.include_router(
prefix_router, prefix="/prefix", dependencies=[create_dependency("prefix_router2")]
)


def test_index():
client = TestClient(app)
with client.websocket_connect("/") as websocket:
data = json.loads(websocket.receive_text())
assert data == ["app", "index"]


def test_routerindex():
client = TestClient(app)
with client.websocket_connect("/router") as websocket:
data = json.loads(websocket.receive_text())
assert data == ["app", "router2", "router", "routerindex"]


def test_routerprefixindex():
client = TestClient(app)
with client.websocket_connect("/prefix/") as websocket:
data = json.loads(websocket.receive_text())
assert data == ["app", "prefix_router2", "prefix_router", "routerprefixindex"]

0 comments on commit d8b8f21

Please sign in to comment.