Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨ Add support for dependencies in WebSocket routes #4534

Merged
merged 12 commits into from
Jun 11, 2023
Merged
27 changes: 23 additions & 4 deletions fastapi/applications.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,15 +401,34 @@ def decorator(func: DecoratedCallable) -> DecoratedCallable:
return decorator

def add_api_websocket_route(
self, path: str, endpoint: Callable[..., Any], name: Optional[str] = None
self,
path: str,
endpoint: Callable[..., Any],
name: Optional[str] = None,
*,
dependencies: Optional[Sequence[Depends]] = None,
) -> None:
self.router.add_api_websocket_route(path, endpoint, name=name)
self.router.add_api_websocket_route(
path,
endpoint,
name=name,
dependencies=dependencies,
)

def websocket(
self, path: str, name: Optional[str] = None
self,
path: str,
name: Optional[str] = None,
*,
dependencies: Optional[Sequence[Depends]] = None,
) -> Callable[[DecoratedCallable], DecoratedCallable]:
def decorator(func: DecoratedCallable) -> DecoratedCallable:
self.add_api_websocket_route(path, func, name=name)
self.add_api_websocket_route(
path,
func,
name=name,
dependencies=dependencies,
)
return func

return decorator
Expand Down
47 changes: 38 additions & 9 deletions fastapi/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,13 +296,21 @@ def __init__(
endpoint: Callable[..., Any],
*,
name: Optional[str] = None,
dependencies: Optional[Sequence[params.Depends]] = None,
dependency_overrides_provider: Optional[Any] = None,
) -> None:
self.path = path
self.endpoint = endpoint
self.name = get_name(endpoint) if name is None else name
self.dependencies = list(dependencies or [])
self.path_regex, self.path_format, self.param_convertors = compile_path(path)
self.dependant = get_dependant(path=self.path_format, call=self.endpoint)
for depends in self.dependencies[::-1]:
self.dependant.dependencies.insert(
0,
get_parameterless_sub_dependant(depends=depends, path=self.path_format),
)

self.app = websocket_session(
get_websocket_app(
dependant=self.dependant,
Expand Down Expand Up @@ -416,10 +424,7 @@ def __init__(
else:
self.response_field = None # type: ignore
self.secure_cloned_response_field = None
if dependencies:
self.dependencies = list(dependencies)
else:
self.dependencies = []
self.dependencies = list(dependencies or [])
paulo-raca marked this conversation as resolved.
Show resolved Hide resolved
self.description = description or inspect.cleandoc(self.endpoint.__doc__ or "")
# if a "form feed" character (page break) is found in the description text,
# truncate description text to the content preceding the first "form feed"
Expand Down Expand Up @@ -514,7 +519,7 @@ def __init__(
), "A path prefix must not end with '/', as the routes will start with '/'"
self.prefix = prefix
self.tags: List[Union[str, Enum]] = tags or []
self.dependencies = list(dependencies or []) or []
self.dependencies = list(dependencies or [])
self.deprecated = deprecated
self.include_in_schema = include_in_schema
self.responses = responses or {}
Expand Down Expand Up @@ -688,21 +693,37 @@ def decorator(func: DecoratedCallable) -> DecoratedCallable:
return decorator

def add_api_websocket_route(
self, path: str, endpoint: Callable[..., Any], name: Optional[str] = None
self,
path: str,
endpoint: Callable[..., Any],
name: Optional[str] = None,
*,
dependencies: Optional[Sequence[params.Depends]] = None,
) -> None:
current_dependencies = self.dependencies.copy()
if dependencies:
current_dependencies.extend(dependencies)

route = APIWebSocketRoute(
self.prefix + path,
endpoint=endpoint,
name=name,
dependencies=current_dependencies,
dependency_overrides_provider=self.dependency_overrides_provider,
)
self.routes.append(route)

def websocket(
self, path: str, name: Optional[str] = None
self,
path: str,
name: Optional[str] = None,
*,
dependencies: Optional[Sequence[params.Depends]] = None,
) -> Callable[[DecoratedCallable], DecoratedCallable]:
def decorator(func: DecoratedCallable) -> DecoratedCallable:
self.add_api_websocket_route(path, func, name=name)
self.add_api_websocket_route(
path, func, name=name, dependencies=dependencies
)
return func

return decorator
Expand Down Expand Up @@ -817,8 +838,16 @@ def include_router(
name=route.name,
)
elif isinstance(route, APIWebSocketRoute):
current_dependencies = []
if dependencies:
current_dependencies.extend(dependencies)
if route.dependencies:
current_dependencies.extend(route.dependencies)
self.add_api_websocket_route(
prefix + route.path, route.endpoint, name=route.name
prefix + route.path,
route.endpoint,
dependencies=current_dependencies,
name=route.name,
)
elif isinstance(route, routing.WebSocketRoute):
self.add_websocket_route(
Expand Down
73 changes: 73 additions & 0 deletions tests/test_ws_dependencies.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import json
from typing import List

from fastapi import APIRouter, Depends, FastAPI, WebSocket
from fastapi.testclient import TestClient
from typing_extensions import Annotated


def dependency_list() -> List[str]:
return []


DepList = Annotated[List[str], Depends(dependency_list)]


def create_dependency(name: str):
def fun(deps: DepList):
deps.append(name)

return Depends(fun)


router = APIRouter(dependencies=[create_dependency("router")])
prefix_router = APIRouter(dependencies=[create_dependency("prefix_router")])
app = FastAPI(dependencies=[create_dependency("app")])


@app.websocket("/", dependencies=[create_dependency("index")])
async def index(websocket: WebSocket, deps: DepList):
await websocket.accept()
await websocket.send_text(json.dumps(deps))
await websocket.close()


@router.websocket("/router", dependencies=[create_dependency("routerindex")])
async def routerindex(websocket: WebSocket, deps: DepList):
await websocket.accept()
await websocket.send_text(json.dumps(deps))
await websocket.close()


@prefix_router.websocket("/", dependencies=[create_dependency("routerprefixindex")])
async def routerprefixindex(websocket: WebSocket, deps: DepList):
await websocket.accept()
await websocket.send_text(json.dumps(deps))
await websocket.close()


app.include_router(router, dependencies=[create_dependency("router2")])
app.include_router(
prefix_router, prefix="/prefix", dependencies=[create_dependency("prefix_router2")]
)


def test_index():
client = TestClient(app)
with client.websocket_connect("/") as websocket:
data = json.loads(websocket.receive_text())
assert data == ["app", "index"]


def test_routerindex():
client = TestClient(app)
with client.websocket_connect("/router") as websocket:
data = json.loads(websocket.receive_text())
assert data == ["app", "router2", "router", "routerindex"]


def test_routerprefixindex():
client = TestClient(app)
with client.websocket_connect("/prefix/") as websocket:
data = json.loads(websocket.receive_text())
assert data == ["app", "prefix_router2", "prefix_router", "routerprefixindex"]