Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support lifespan state #2060

Merged
merged 7 commits into from
Mar 5, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 2 additions & 4 deletions starlette/applications.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from starlette.requests import Request
from starlette.responses import Response
from starlette.routing import BaseRoute, Router
from starlette.types import ASGIApp, Receive, Scope, Send
from starlette.types import ASGIApp, Lifespan, Receive, Scope, Send


class Starlette:
Expand Down Expand Up @@ -55,9 +55,7 @@ def __init__(
] = None,
on_startup: typing.Optional[typing.Sequence[typing.Callable]] = None,
on_shutdown: typing.Optional[typing.Sequence[typing.Callable]] = None,
lifespan: typing.Optional[
typing.Callable[["Starlette"], typing.AsyncContextManager]
] = None,
lifespan: typing.Optional[Lifespan] = None,
) -> None:
# The lifespan context function is a newer style that replaces
# on_startup / on_shutdown handlers. Use one or the other, not both.
Expand Down
52 changes: 37 additions & 15 deletions starlette/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from starlette.middleware import Middleware
from starlette.requests import Request
from starlette.responses import PlainTextResponse, RedirectResponse
from starlette.types import ASGIApp, Receive, Scope, Send
from starlette.types import ASGIApp, Lifespan, Receive, Scope, Send, StatelessLifespan
from starlette.websockets import WebSocket, WebSocketClose


Expand Down Expand Up @@ -558,17 +558,25 @@ def wrapper(app: typing.Any) -> _AsyncLiftContextManager:
return wrapper


_TDefaultLifespan = typing.TypeVar("_TDefaultLifespan", bound="_DefaultLifespan")


class _DefaultLifespan:
def __init__(self, router: "Router"):
self._router = router

async def __aenter__(self) -> None:
await self._router.startup()
await self._router.startup(state=self._state)

async def __aexit__(self, *exc_info: object) -> None:
await self._router.shutdown()

def __call__(self: _T, app: object) -> _T:
await self._router.shutdown(state=self._state)

def __call__(
self: _TDefaultLifespan,
app: object,
state: typing.Optional[typing.Dict[str, typing.Any]],
) -> _TDefaultLifespan:
self._state = state
return self
Comment on lines +574 to 580
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At first glance I'm a bit confused about this and the need for a TypeVar.

Copy link
Sponsor Member Author

@Kludex Kludex Mar 4, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reason for the TypeVar is that it needed to be bound, as the _state was introduced here, since the self is annotated.



Expand All @@ -580,9 +588,7 @@ def __init__(
default: typing.Optional[ASGIApp] = None,
on_startup: typing.Optional[typing.Sequence[typing.Callable]] = None,
on_shutdown: typing.Optional[typing.Sequence[typing.Callable]] = None,
lifespan: typing.Optional[
typing.Callable[[typing.Any], typing.AsyncContextManager]
] = None,
lifespan: typing.Optional[Lifespan] = None,
) -> None:
self.routes = [] if routes is None else list(routes)
self.redirect_slashes = redirect_slashes
Expand All @@ -591,10 +597,7 @@ def __init__(
self.on_shutdown = [] if on_shutdown is None else list(on_shutdown)

if lifespan is None:
self.lifespan_context: typing.Callable[
[typing.Any], typing.AsyncContextManager
] = _DefaultLifespan(self)

self.lifespan_context: Lifespan = _DefaultLifespan(self)
elif inspect.isasyncgenfunction(lifespan):
Kludex marked this conversation as resolved.
Show resolved Hide resolved
warnings.warn(
"async generator function lifespans are deprecated, "
Expand Down Expand Up @@ -639,21 +642,31 @@ def url_path_for(self, name: str, **path_params: typing.Any) -> URLPath:
pass
raise NoMatchFound(name, path_params)

async def startup(self) -> None:
async def startup(
self, state: typing.Optional[typing.Dict[str, typing.Any]]
) -> None:
"""
Run any `.on_startup` event handlers.
"""
for handler in self.on_startup:
sig = inspect.signature(handler)
if len(sig.parameters) == 1 and state is not None:
Copy link
Sponsor Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the state is None, it means that the server doesn't support it, so we'll maintain the same error message.

handler = functools.partial(handler, state)
if is_async_callable(handler):
await handler()
else:
handler()

async def shutdown(self) -> None:
async def shutdown(
self, state: typing.Optional[typing.Dict[str, typing.Any]]
) -> None:
"""
Run any `.on_shutdown` event handlers.
"""
for handler in self.on_shutdown:
sig = inspect.signature(handler)
if len(sig.parameters) == 1 and state is not None:
handler = functools.partial(handler, state)
if is_async_callable(handler):
await handler()
else:
Expand All @@ -666,9 +679,18 @@ async def lifespan(self, scope: Scope, receive: Receive, send: Send) -> None:
"""
started = False
app = scope.get("app")
state = scope.get("state")
await receive()
try:
async with self.lifespan_context(app):
lifespan_context: Lifespan
if (
len(inspect.signature(self.lifespan_context).parameters) == 2
and state is not None
):
lifespan_context = functools.partial(self.lifespan_context, state=state)
else:
lifespan_context = typing.cast(StatelessLifespan, self.lifespan_context)
async with lifespan_context(app):
await send({"type": "lifespan.startup.complete"})
started = True
await receive()
Expand Down
9 changes: 8 additions & 1 deletion starlette/testclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,11 +188,14 @@ def __init__(
portal_factory: _PortalFactoryType,
raise_server_exceptions: bool = True,
root_path: str = "",
*,
app_state: typing.Dict[str, typing.Any],
) -> None:
self.app = app
self.raise_server_exceptions = raise_server_exceptions
self.root_path = root_path
self.portal_factory = portal_factory
self.app_state = app_state

def handle_request(self, request: httpx.Request) -> httpx.Response:
scheme = request.url.scheme
Expand Down Expand Up @@ -243,6 +246,7 @@ def handle_request(self, request: httpx.Request) -> httpx.Response:
"client": ["testclient", 50000],
"server": [host, port],
"subprotocols": subprotocols,
"state": self.app_state.copy(),
}
session = WebSocketTestSession(self.app, scope, self.portal_factory)
raise _Upgrade(session)
Expand All @@ -260,6 +264,7 @@ def handle_request(self, request: httpx.Request) -> httpx.Response:
"client": ["testclient", 50000],
"server": [host, port],
"extensions": {"http.response.debug": {}},
"state": self.app_state.copy(),
}

request_complete = False
Expand Down Expand Up @@ -380,11 +385,13 @@ def __init__(
app = typing.cast(ASGI2App, app) # type: ignore[assignment]
asgi_app = _WrapASGI2(app) # type: ignore[arg-type]
self.app = asgi_app
self.app_state: typing.Dict[str, typing.Any] = {}
transport = _TestClientTransport(
self.app,
portal_factory=self._portal_factory,
raise_server_exceptions=raise_server_exceptions,
root_path=root_path,
app_state=self.app_state,
)
if headers is None:
headers = {}
Expand Down Expand Up @@ -749,7 +756,7 @@ def __exit__(self, *args: typing.Any) -> None:
self.exit_stack.close()

async def lifespan(self) -> None:
scope = {"type": "lifespan"}
scope = {"type": "lifespan", "state": self.app_state}
try:
await self.app(scope, self.stream_receive.receive, self.stream_send.send)
finally:
Expand Down
6 changes: 6 additions & 0 deletions starlette/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,9 @@
Send = typing.Callable[[Message], typing.Awaitable[None]]

ASGIApp = typing.Callable[[Scope, Receive, Send], typing.Awaitable[None]]

StatelessLifespan = typing.Callable[[object], typing.AsyncContextManager]
StateLifespan = typing.Callable[
[typing.Any, typing.Dict[str, typing.Any]], typing.AsyncContextManager
]
Lifespan = typing.Union[StatelessLifespan, StateLifespan]
Kludex marked this conversation as resolved.
Show resolved Hide resolved
47 changes: 47 additions & 0 deletions tests/test_routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -669,6 +669,53 @@ def run_shutdown():
assert shutdown_complete


def test_lifespan_with_state(test_client_factory):
startup_complete = False
shutdown_complete = False

async def hello_world(request):
# modifications to the state should not leak across requests
assert request.state.count == 0
# modify the state, this should not leak to the lifespan or other requests
request.state.count += 1
# since state.list is a mutable object this modification _will_ leak across
# requests and to the lifespan
request.state.list.append(1)
return PlainTextResponse("hello, world")

async def run_startup(state):
nonlocal startup_complete
startup_complete = True
state["count"] = 0
state["list"] = []

async def run_shutdown(state):
nonlocal shutdown_complete
shutdown_complete = True
# modifications made to the state from a request do not leak to the lifespan
assert state["count"] == 0
# unless of course the request mutates a mutable object that is referenced
# via state
assert state["list"] == [1, 1]
Kludex marked this conversation as resolved.
Show resolved Hide resolved

app = Router(
on_startup=[run_startup],
on_shutdown=[run_shutdown],
routes=[Route("/", hello_world)],
)

assert not startup_complete
assert not shutdown_complete
with test_client_factory(app) as client:
assert startup_complete
assert not shutdown_complete
client.get("/")
# Calling it a second time to ensure that the state is preserved.
client.get("/")
assert startup_complete
assert shutdown_complete


def test_raise_on_startup(test_client_factory):
def run_startup():
raise RuntimeError()
Expand Down