Skip to content

Commit

Permalink
Revert "Support lifespan state (encode#2060)"
Browse files Browse the repository at this point in the history
This reverts commit da6461b.
  • Loading branch information
adriangb committed Mar 5, 2023
1 parent 5472c44 commit 44159d0
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 172 deletions.
6 changes: 4 additions & 2 deletions starlette/applications.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from starlette.requests import Request
from starlette.responses import Response
from starlette.routing import BaseRoute, Router
from starlette.types import ASGIApp, Lifespan, Receive, Scope, Send
from starlette.types import ASGIApp, Receive, Scope, Send


class Starlette:
Expand Down Expand Up @@ -55,7 +55,9 @@ def __init__(
] = None,
on_startup: typing.Optional[typing.Sequence[typing.Callable]] = None,
on_shutdown: typing.Optional[typing.Sequence[typing.Callable]] = None,
lifespan: typing.Optional[Lifespan] = None,
lifespan: typing.Optional[
typing.Callable[["Starlette"], typing.AsyncContextManager]
] = None,
) -> None:
# The lifespan context function is a newer style that replaces
# on_startup / on_shutdown handlers. Use one or the other, not both.
Expand Down
57 changes: 15 additions & 42 deletions starlette/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from starlette.middleware import Middleware
from starlette.requests import Request
from starlette.responses import PlainTextResponse, RedirectResponse
from starlette.types import ASGIApp, Lifespan, Receive, Scope, Send, StatelessLifespan
from starlette.types import ASGIApp, Receive, Scope, Send
from starlette.websockets import WebSocket, WebSocketClose


Expand Down Expand Up @@ -558,25 +558,17 @@ def wrapper(app: typing.Any) -> _AsyncLiftContextManager:
return wrapper


_TDefaultLifespan = typing.TypeVar("_TDefaultLifespan", bound="_DefaultLifespan")


class _DefaultLifespan:
def __init__(self, router: "Router"):
self._router = router

async def __aenter__(self) -> None:
await self._router.startup(state=self._state)
await self._router.startup()

async def __aexit__(self, *exc_info: object) -> None:
await self._router.shutdown(state=self._state)

def __call__(
self: _TDefaultLifespan,
app: object,
state: typing.Optional[typing.Dict[str, typing.Any]],
) -> _TDefaultLifespan:
self._state = state
await self._router.shutdown()

def __call__(self: _T, app: object) -> _T:
return self


Expand All @@ -588,7 +580,9 @@ def __init__(
default: typing.Optional[ASGIApp] = None,
on_startup: typing.Optional[typing.Sequence[typing.Callable]] = None,
on_shutdown: typing.Optional[typing.Sequence[typing.Callable]] = None,
lifespan: typing.Optional[Lifespan] = None,
lifespan: typing.Optional[
typing.Callable[[typing.Any], typing.AsyncContextManager]
] = None,
) -> None:
self.routes = [] if routes is None else list(routes)
self.redirect_slashes = redirect_slashes
Expand All @@ -597,7 +591,10 @@ def __init__(
self.on_shutdown = [] if on_shutdown is None else list(on_shutdown)

if lifespan is None:
self.lifespan_context: Lifespan = _DefaultLifespan(self)
self.lifespan_context: typing.Callable[
[typing.Any], typing.AsyncContextManager
] = _DefaultLifespan(self)

elif inspect.isasyncgenfunction(lifespan):
warnings.warn(
"async generator function lifespans are deprecated, "
Expand Down Expand Up @@ -642,31 +639,21 @@ def url_path_for(self, name: str, **path_params: typing.Any) -> URLPath:
pass
raise NoMatchFound(name, path_params)

async def startup(
self, state: typing.Optional[typing.Dict[str, typing.Any]]
) -> None:
async def startup(self) -> None:
"""
Run any `.on_startup` event handlers.
"""
for handler in self.on_startup:
sig = inspect.signature(handler)
if len(sig.parameters) == 1 and state is not None:
handler = functools.partial(handler, state)
if is_async_callable(handler):
await handler()
else:
handler()

async def shutdown(
self, state: typing.Optional[typing.Dict[str, typing.Any]]
) -> None:
async def shutdown(self) -> None:
"""
Run any `.on_shutdown` event handlers.
"""
for handler in self.on_shutdown:
sig = inspect.signature(handler)
if len(sig.parameters) == 1 and state is not None:
handler = functools.partial(handler, state)
if is_async_callable(handler):
await handler()
else:
Expand All @@ -679,23 +666,9 @@ async def lifespan(self, scope: Scope, receive: Receive, send: Send) -> None:
"""
started = False
app = scope.get("app")
state = scope.get("state")
await receive()
lifespan_needs_state = (
len(inspect.signature(self.lifespan_context).parameters) == 2
)
server_supports_state = state is not None
if lifespan_needs_state and not server_supports_state:
raise RuntimeError(
'The server does not support "state" in the lifespan scope.'
)
try:
lifespan_context: Lifespan
if lifespan_needs_state:
lifespan_context = functools.partial(self.lifespan_context, state=state)
else:
lifespan_context = typing.cast(StatelessLifespan, self.lifespan_context)
async with lifespan_context(app):
async with self.lifespan_context(app):
await send({"type": "lifespan.startup.complete"})
started = True
await receive()
Expand Down
9 changes: 1 addition & 8 deletions starlette/testclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,14 +188,11 @@ def __init__(
portal_factory: _PortalFactoryType,
raise_server_exceptions: bool = True,
root_path: str = "",
*,
app_state: typing.Dict[str, typing.Any],
) -> None:
self.app = app
self.raise_server_exceptions = raise_server_exceptions
self.root_path = root_path
self.portal_factory = portal_factory
self.app_state = app_state

def handle_request(self, request: httpx.Request) -> httpx.Response:
scheme = request.url.scheme
Expand Down Expand Up @@ -246,7 +243,6 @@ def handle_request(self, request: httpx.Request) -> httpx.Response:
"client": ["testclient", 50000],
"server": [host, port],
"subprotocols": subprotocols,
"state": self.app_state.copy(),
}
session = WebSocketTestSession(self.app, scope, self.portal_factory)
raise _Upgrade(session)
Expand All @@ -264,7 +260,6 @@ def handle_request(self, request: httpx.Request) -> httpx.Response:
"client": ["testclient", 50000],
"server": [host, port],
"extensions": {"http.response.debug": {}},
"state": self.app_state.copy(),
}

request_complete = False
Expand Down Expand Up @@ -385,13 +380,11 @@ def __init__(
app = typing.cast(ASGI2App, app) # type: ignore[assignment]
asgi_app = _WrapASGI2(app) # type: ignore[arg-type]
self.app = asgi_app
self.app_state: typing.Dict[str, typing.Any] = {}
transport = _TestClientTransport(
self.app,
portal_factory=self._portal_factory,
raise_server_exceptions=raise_server_exceptions,
root_path=root_path,
app_state=self.app_state,
)
if headers is None:
headers = {}
Expand Down Expand Up @@ -756,7 +749,7 @@ def __exit__(self, *args: typing.Any) -> None:
self.exit_stack.close()

async def lifespan(self) -> None:
scope = {"type": "lifespan", "state": self.app_state}
scope = {"type": "lifespan"}
try:
await self.app(scope, self.stream_receive.receive, self.stream_send.send)
finally:
Expand Down
6 changes: 0 additions & 6 deletions starlette/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,3 @@
Send = typing.Callable[[Message], typing.Awaitable[None]]

ASGIApp = typing.Callable[[Scope, Receive, Send], typing.Awaitable[None]]

StatelessLifespan = typing.Callable[[object], typing.AsyncContextManager[typing.Any]]
StateLifespan = typing.Callable[
[typing.Any, typing.Dict[str, typing.Any]], typing.AsyncContextManager[typing.Any]
]
Lifespan = typing.Union[StatelessLifespan, StateLifespan]
114 changes: 0 additions & 114 deletions tests/test_routing.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import contextlib
import functools
import typing
import uuid
Expand Down Expand Up @@ -670,119 +669,6 @@ def run_shutdown():
assert shutdown_complete


def test_lifespan_with_state(test_client_factory):
startup_complete = False
shutdown_complete = False

async def hello_world(request):
# modifications to the state should not leak across requests
assert request.state.count == 0
# modify the state, this should not leak to the lifespan or other requests
request.state.count += 1
# since state.list is a mutable object this modification _will_ leak across
# requests and to the lifespan
request.state.list.append(1)
return PlainTextResponse("hello, world")

async def run_startup(state):
nonlocal startup_complete
startup_complete = True
state["count"] = 0
state["list"] = []

async def run_shutdown(state):
nonlocal shutdown_complete
shutdown_complete = True
# modifications made to the state from a request do not leak to the lifespan
assert state["count"] == 0
# unless of course the request mutates a mutable object that is referenced
# via state
assert state["list"] == [1, 1]

app = Router(
on_startup=[run_startup],
on_shutdown=[run_shutdown],
routes=[Route("/", hello_world)],
)

assert not startup_complete
assert not shutdown_complete
with test_client_factory(app) as client:
assert startup_complete
assert not shutdown_complete
client.get("/")
# Calling it a second time to ensure that the state is preserved.
client.get("/")
assert startup_complete
assert shutdown_complete


def test_lifespan_state_unsupported(test_client_factory):
@contextlib.asynccontextmanager
async def lifespan(app, scope):
yield None # pragma: no cover

app = Router(
lifespan=lifespan,
routes=[Mount("/", PlainTextResponse("hello, world"))],
)

async def no_state_wrapper(scope, receive, send):
del scope["state"]
await app(scope, receive, send)

with pytest.raises(
RuntimeError, match='The server does not support "state" in the lifespan scope'
):
with test_client_factory(no_state_wrapper):
raise AssertionError("Should not be called") # pragma: no cover


def test_lifespan_async_cm(test_client_factory):
startup_complete = False
shutdown_complete = False

async def hello_world(request):
# modifications to the state should not leak across requests
assert request.state.count == 0
# modify the state, this should not leak to the lifespan or other requests
request.state.count += 1
# since state.list is a mutable object this modification _will_ leak across
# requests and to the lifespan
request.state.list.append(1)
return PlainTextResponse("hello, world")

@contextlib.asynccontextmanager
async def lifespan(app: Starlette, state: typing.Dict[str, typing.Any]):
nonlocal startup_complete, shutdown_complete
startup_complete = True
state["count"] = 0
state["list"] = []
yield
shutdown_complete = True
# modifications made to the state from a request do not leak to the lifespan
assert state["count"] == 0
# unless of course the request mutates a mutable object that is referenced
# via state
assert state["list"] == [1, 1]

app = Router(
lifespan=lifespan,
routes=[Route("/", hello_world)],
)

assert not startup_complete
assert not shutdown_complete
with test_client_factory(app) as client:
assert startup_complete
assert not shutdown_complete
client.get("/")
# Calling it a second time to ensure that the state is preserved.
client.get("/")
assert startup_complete
assert shutdown_complete


def test_raise_on_startup(test_client_factory):
def run_startup():
raise RuntimeError()
Expand Down

0 comments on commit 44159d0

Please sign in to comment.