Skip to content

Commit

Permalink
Support lifespan state
Browse files Browse the repository at this point in the history
  • Loading branch information
Kludex committed Mar 4, 2023
1 parent e594de7 commit ba1588a
Show file tree
Hide file tree
Showing 6 changed files with 102 additions and 113 deletions.
6 changes: 2 additions & 4 deletions starlette/applications.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from starlette.requests import Request
from starlette.responses import Response
from starlette.routing import BaseRoute, Router
from starlette.types import ASGIApp, Receive, Scope, Send
from starlette.types import ASGIApp, Lifespan, Receive, Scope, Send


class Starlette:
Expand Down Expand Up @@ -55,9 +55,7 @@ def __init__(
] = None,
on_startup: typing.Optional[typing.Sequence[typing.Callable]] = None,
on_shutdown: typing.Optional[typing.Sequence[typing.Callable]] = None,
lifespan: typing.Optional[
typing.Callable[["Starlette"], typing.AsyncContextManager]
] = None,
lifespan: typing.Optional[Lifespan] = None,
) -> None:
# The lifespan context function is a newer style that replaces
# on_startup / on_shutdown handlers. Use one or the other, not both.
Expand Down
102 changes: 41 additions & 61 deletions starlette/routing.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
import contextlib
import functools
import inspect
import re
import traceback
import types
import typing
import warnings
from contextlib import asynccontextmanager
from enum import Enum

from starlette._utils import is_async_callable
Expand All @@ -17,7 +14,7 @@
from starlette.middleware import Middleware
from starlette.requests import Request
from starlette.responses import PlainTextResponse, RedirectResponse
from starlette.types import ASGIApp, Receive, Scope, Send
from starlette.types import ASGIApp, Lifespan, Receive, Scope, Send, StatelessLifespan
from starlette.websockets import WebSocket, WebSocketClose


Expand Down Expand Up @@ -530,45 +527,25 @@ def __repr__(self) -> str:
_T = typing.TypeVar("_T")


class _AsyncLiftContextManager(typing.AsyncContextManager[_T]):
def __init__(self, cm: typing.ContextManager[_T]):
self._cm = cm

async def __aenter__(self) -> _T:
return self._cm.__enter__()

async def __aexit__(
self,
exc_type: typing.Optional[typing.Type[BaseException]],
exc_value: typing.Optional[BaseException],
traceback: typing.Optional[types.TracebackType],
) -> typing.Optional[bool]:
return self._cm.__exit__(exc_type, exc_value, traceback)


def _wrap_gen_lifespan_context(
lifespan_context: typing.Callable[[typing.Any], typing.Generator]
) -> typing.Callable[[typing.Any], typing.AsyncContextManager]:
cmgr = contextlib.contextmanager(lifespan_context)

@functools.wraps(cmgr)
def wrapper(app: typing.Any) -> _AsyncLiftContextManager:
return _AsyncLiftContextManager(cmgr(app))

return wrapper
_TDefaultLifespan = typing.TypeVar("_TDefaultLifespan", bound="_DefaultLifespan")


class _DefaultLifespan:
def __init__(self, router: "Router"):
self._router = router

async def __aenter__(self) -> None:
await self._router.startup()
await self._router.startup(state=self._state)

async def __aexit__(self, *exc_info: object) -> None:
await self._router.shutdown()

def __call__(self: _T, app: object) -> _T:
await self._router.shutdown(state=self._state)

def __call__(
self: _TDefaultLifespan,
app: object,
state: typing.Optional[typing.Dict[str, typing.Any]],
) -> _TDefaultLifespan:
self._state = state
return self


Expand All @@ -580,9 +557,7 @@ def __init__(
default: typing.Optional[ASGIApp] = None,
on_startup: typing.Optional[typing.Sequence[typing.Callable]] = None,
on_shutdown: typing.Optional[typing.Sequence[typing.Callable]] = None,
lifespan: typing.Optional[
typing.Callable[[typing.Any], typing.AsyncContextManager]
] = None,
lifespan: typing.Optional[Lifespan] = None,
) -> None:
self.routes = [] if routes is None else list(routes)
self.redirect_slashes = redirect_slashes
Expand All @@ -591,27 +566,13 @@ def __init__(
self.on_shutdown = [] if on_shutdown is None else list(on_shutdown)

if lifespan is None:
self.lifespan_context: typing.Callable[
[typing.Any], typing.AsyncContextManager
] = _DefaultLifespan(self)

elif inspect.isasyncgenfunction(lifespan):
warnings.warn(
"async generator function lifespans are deprecated, "
"use an @contextlib.asynccontextmanager function instead",
DeprecationWarning,
)
self.lifespan_context = asynccontextmanager(
lifespan, # type: ignore[arg-type]
)
elif inspect.isgeneratorfunction(lifespan):
warnings.warn(
"generator function lifespans are deprecated, "
"use an @contextlib.asynccontextmanager function instead",
DeprecationWarning,
)
self.lifespan_context = _wrap_gen_lifespan_context(
lifespan, # type: ignore[arg-type]
self.lifespan_context: Lifespan = _DefaultLifespan(self)
elif inspect.isasyncgenfunction(lifespan) or inspect.isgeneratorfunction(
lifespan
):
raise RuntimeError(
"Generator functions are not supported for lifespan, "
"use an @contextlib.asynccontextmanager function instead."
)
else:
self.lifespan_context = lifespan
Expand Down Expand Up @@ -639,21 +600,31 @@ def url_path_for(self, name: str, **path_params: typing.Any) -> URLPath:
pass
raise NoMatchFound(name, path_params)

async def startup(self) -> None:
async def startup(
self, state: typing.Optional[typing.Dict[str, typing.Any]]
) -> None:
"""
Run any `.on_startup` event handlers.
"""
for handler in self.on_startup:
sig = inspect.signature(handler)
if len(sig.parameters) == 1 and state is not None:
handler = functools.partial(handler, state)
if is_async_callable(handler):
await handler()
else:
handler()

async def shutdown(self) -> None:
async def shutdown(
self, state: typing.Optional[typing.Dict[str, typing.Any]]
) -> None:
"""
Run any `.on_shutdown` event handlers.
"""
for handler in self.on_shutdown:
sig = inspect.signature(handler)
if len(sig.parameters) == 1 and state is not None:
handler = functools.partial(handler, state)
if is_async_callable(handler):
await handler()
else:
Expand All @@ -666,9 +637,18 @@ async def lifespan(self, scope: Scope, receive: Receive, send: Send) -> None:
"""
started = False
app = scope.get("app")
state = scope.get("state")
await receive()
try:
async with self.lifespan_context(app):
lifespan_context: Lifespan
if (
len(inspect.signature(self.lifespan_context).parameters) == 2
and state is not None
):
lifespan_context = functools.partial(self.lifespan_context, state=state)
else:
lifespan_context = typing.cast(StatelessLifespan, self.lifespan_context)
async with lifespan_context(app):
await send({"type": "lifespan.startup.complete"})
started = True
await receive()
Expand Down
8 changes: 7 additions & 1 deletion starlette/testclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,11 +188,13 @@ def __init__(
portal_factory: _PortalFactoryType,
raise_server_exceptions: bool = True,
root_path: str = "",
state: typing.Optional[typing.Dict[str, typing.Any]] = None,
) -> None:
self.app = app
self.raise_server_exceptions = raise_server_exceptions
self.root_path = root_path
self.portal_factory = portal_factory
self.state = state

def handle_request(self, request: httpx.Request) -> httpx.Response:
scheme = request.url.scheme
Expand Down Expand Up @@ -243,6 +245,7 @@ def handle_request(self, request: httpx.Request) -> httpx.Response:
"client": ["testclient", 50000],
"server": [host, port],
"subprotocols": subprotocols,
"state": self.state,
}
session = WebSocketTestSession(self.app, scope, self.portal_factory)
raise _Upgrade(session)
Expand All @@ -260,6 +263,7 @@ def handle_request(self, request: httpx.Request) -> httpx.Response:
"client": ["testclient", 50000],
"server": [host, port],
"extensions": {"http.response.debug": {}},
"state": self.state,
}

request_complete = False
Expand Down Expand Up @@ -380,11 +384,13 @@ def __init__(
app = typing.cast(ASGI2App, app) # type: ignore[assignment]
asgi_app = _WrapASGI2(app) # type: ignore[arg-type]
self.app = asgi_app
self.app_state: typing.Dict[str, typing.Any] = {}
transport = _TestClientTransport(
self.app,
portal_factory=self._portal_factory,
raise_server_exceptions=raise_server_exceptions,
root_path=root_path,
state=self.app_state,
)
if headers is None:
headers = {}
Expand Down Expand Up @@ -749,7 +755,7 @@ def __exit__(self, *args: typing.Any) -> None:
self.exit_stack.close()

async def lifespan(self) -> None:
scope = {"type": "lifespan"}
scope = {"type": "lifespan", "state": self.app_state}
try:
await self.app(scope, self.stream_receive.receive, self.stream_send.send)
finally:
Expand Down
6 changes: 6 additions & 0 deletions starlette/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,9 @@
Send = typing.Callable[[Message], typing.Awaitable[None]]

ASGIApp = typing.Callable[[Scope, Receive, Send], typing.Awaitable[None]]

StatelessLifespan = typing.Callable[[object], typing.AsyncContextManager]
StateLifespan = typing.Callable[
[typing.Any, typing.Dict[str, typing.Any]], typing.AsyncContextManager
]
Lifespan = typing.Union[StatelessLifespan, StateLifespan]
59 changes: 12 additions & 47 deletions tests/test_applications.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,57 +381,22 @@ async def lifespan(app):
assert cleanup_complete


deprecated_lifespan = pytest.mark.filterwarnings(
r"ignore"
r":(async )?generator function lifespans are deprecated, use an "
r"@contextlib\.asynccontextmanager function instead"
r":DeprecationWarning"
r":starlette.routing"
)


@deprecated_lifespan
def test_app_async_gen_lifespan(test_client_factory):
startup_complete = False
cleanup_complete = False

async def lifespan(app):
nonlocal startup_complete, cleanup_complete
startup_complete = True
yield
cleanup_complete = True

app = Starlette(lifespan=lifespan)

assert not startup_complete
assert not cleanup_complete
with test_client_factory(app):
assert startup_complete
assert not cleanup_complete
assert startup_complete
assert cleanup_complete
async def async_gen_lifespan():
yield # pragma: no cover


@deprecated_lifespan
def test_app_sync_gen_lifespan(test_client_factory):
startup_complete = False
cleanup_complete = False

def lifespan(app):
nonlocal startup_complete, cleanup_complete
startup_complete = True
yield
cleanup_complete = True
def sync__gen_lifespan():
yield # pragma: no cover

app = Starlette(lifespan=lifespan)

assert not startup_complete
assert not cleanup_complete
with test_client_factory(app):
assert startup_complete
assert not cleanup_complete
assert startup_complete
assert cleanup_complete
@pytest.mark.parametrize("lifespan", [async_gen_lifespan, sync__gen_lifespan])
def test_app_gen_lifespan(lifespan):
with pytest.raises(
RuntimeError,
match="Generator functions are not supported for lifespan"
", use an @contextlib.asynccontextmanager function instead.",
):
Starlette(lifespan=lifespan)


def test_decorator_deprecations() -> None:
Expand Down
34 changes: 34 additions & 0 deletions tests/test_routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -669,6 +669,40 @@ def run_shutdown():
assert shutdown_complete


def test_lifespan_with_state(test_client_factory):
startup_complete = False
shutdown_complete = False

async def hello_world(request):
assert request.state.startup
return PlainTextResponse("hello, world")

async def run_startup(state):
nonlocal startup_complete
startup_complete = True
state["startup"] = True

async def run_shutdown(state):
nonlocal shutdown_complete
shutdown_complete = True
assert state["startup"]

app = Router(
on_startup=[run_startup],
on_shutdown=[run_shutdown],
routes=[Route("/", hello_world)],
)

assert not startup_complete
assert not shutdown_complete
with test_client_factory(app) as client:
assert startup_complete
assert not shutdown_complete
client.get("/")
assert startup_complete
assert shutdown_complete


def test_raise_on_startup(test_client_factory):
def run_startup():
raise RuntimeError()
Expand Down

0 comments on commit ba1588a

Please sign in to comment.