Skip to content

Commit

Permalink
Support dependencies in websocket routes
Browse files Browse the repository at this point in the history
I've been using dependencies to handle authentication.
But imagine my surprise when I realized my websocket endpoint wasn't authenticated at all?

This commit cherry-picks the `dependencies` chunks from `APIRoute` into `APIWebSocketRoute`

I also made a few minor style nit-picks
  • Loading branch information
paulo-raca committed Feb 10, 2022
1 parent b93f8a7 commit 0ac85d4
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 15 deletions.
28 changes: 24 additions & 4 deletions fastapi/applications.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,15 +325,35 @@ def decorator(func: DecoratedCallable) -> DecoratedCallable:
return decorator

def add_api_websocket_route(
self, path: str, endpoint: Callable[..., Any], name: Optional[str] = None
self,
path: str,
endpoint: Callable[..., Any],
*,
dependencies: Optional[Sequence[Depends]] = None,
name: Optional[str] = None,

) -> None:
self.router.add_api_websocket_route(path, endpoint, name=name)
self.router.add_api_websocket_route(
path,
endpoint,
dependencies=dependencies,
name=name,
)

def websocket(
self, path: str, name: Optional[str] = None
self,
path: str,
*,
dependencies: Optional[Sequence[Depends]] = None,
name: Optional[str] = None
) -> Callable[[DecoratedCallable], DecoratedCallable]:
def decorator(func: DecoratedCallable) -> DecoratedCallable:
self.add_api_websocket_route(path, func, name=name)
self.add_api_websocket_route(
path,
func,
dependencies=dependencies,
name=name,
)
return func

return decorator
Expand Down
49 changes: 38 additions & 11 deletions fastapi/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,20 +281,27 @@ def __init__(
path: str,
endpoint: Callable[..., Any],
*,
dependencies: Optional[Sequence[params.Depends]] = None,
name: Optional[str] = None,
dependency_overrides_provider: Optional[Any] = None,
) -> None:
self.path = path
self.endpoint = endpoint
self.dependencies = dependencies or []
self.name = get_name(endpoint) if name is None else name
self.path_regex, self.path_format, self.param_convertors = compile_path(path)
self.dependant = get_dependant(path=path, call=self.endpoint)
for depends in self.dependencies[::-1]:
self.dependant.dependencies.insert(
0,
get_parameterless_sub_dependant(depends=depends, path=self.path_format),
)
self.app = websocket_session(
get_websocket_app(
dependant=self.dependant,
dependency_overrides_provider=dependency_overrides_provider,
)
)
self.path_regex, self.path_format, self.param_convertors = compile_path(path)


class APIRoute(routing.Route):
Expand Down Expand Up @@ -366,10 +373,7 @@ def __init__(
self.secure_cloned_response_field = None
self.status_code = status_code
self.tags = tags or []
if dependencies:
self.dependencies = list(dependencies)
else:
self.dependencies = []
self.dependencies = dependencies or []
self.summary = summary
self.description = description or inspect.cleandoc(self.endpoint.__doc__ or "")
# if a "form feed" character (page break) is found in the description text,
Expand Down Expand Up @@ -467,7 +471,7 @@ def __init__(
), "A path prefix must not end with '/', as the routes will start with '/'"
self.prefix = prefix
self.tags: List[Union[str, Enum]] = tags or []
self.dependencies = list(dependencies or []) or []
self.dependencies = dependencies or []
self.deprecated = deprecated
self.include_in_schema = include_in_schema
self.responses = responses or {}
Expand Down Expand Up @@ -516,7 +520,7 @@ def add_api_route(
current_tags = self.tags.copy()
if tags:
current_tags.extend(tags)
current_dependencies = self.dependencies.copy()
current_dependencies = list(self.dependencies)
if dependencies:
current_dependencies.extend(dependencies)
current_callbacks = self.callbacks.copy()
Expand Down Expand Up @@ -610,21 +614,39 @@ def decorator(func: DecoratedCallable) -> DecoratedCallable:
return decorator

def add_api_websocket_route(
self, path: str, endpoint: Callable[..., Any], name: Optional[str] = None
self,
path: str,
endpoint: Callable[..., Any],
*,
dependencies: Optional[Sequence[params.Depends]] = None,
name: Optional[str] = None,
) -> None:
current_dependencies = list(self.dependencies)
if dependencies:
current_dependencies.extend(dependencies)

route = APIWebSocketRoute(
path,
endpoint=endpoint,
name=name,
dependency_overrides_provider=self.dependency_overrides_provider,
dependencies=current_dependencies,
)
self.routes.append(route)

def websocket(
self, path: str, name: Optional[str] = None
self,
path: str,
*,
dependencies: Optional[Sequence[params.Depends]] = None,
name: Optional[str] = None,
) -> Callable[[DecoratedCallable], DecoratedCallable]:
def decorator(func: DecoratedCallable) -> DecoratedCallable:
self.add_api_websocket_route(path, func, name=name)
self.add_api_websocket_route(
path,
func,
name=name,
dependencies=dependencies)
return func

return decorator
Expand Down Expand Up @@ -720,8 +742,13 @@ def include_router(
name=route.name,
)
elif isinstance(route, APIWebSocketRoute):
current_dependencies: List[params.Depends] = []
if dependencies:
current_dependencies.extend(dependencies)
if route.dependencies:
current_dependencies.extend(route.dependencies)
self.add_api_websocket_route(
prefix + route.path, route.endpoint, name=route.name
prefix + route.path, route.endpoint, dependencies=current_dependencies, name=route.name
)
elif isinstance(route, routing.WebSocketRoute):
self.add_websocket_route(
Expand Down

0 comments on commit 0ac85d4

Please sign in to comment.