Skip to content

Commit

Permalink
Revert "Support lifespan state (#2060)"
Browse files Browse the repository at this point in the history
This reverts commit da6461b.
  • Loading branch information
adriangb authored and Kludex committed Mar 9, 2023
1 parent 92ab71e commit 2d34ccc
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 81 deletions.
6 changes: 4 additions & 2 deletions starlette/applications.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from starlette.requests import Request
from starlette.responses import Response
from starlette.routing import BaseRoute, Router
from starlette.types import ASGIApp, Lifespan, Receive, Scope, Send
from starlette.types import ASGIApp, Receive, Scope, Send


class Starlette:
Expand Down Expand Up @@ -55,7 +55,9 @@ def __init__(
] = None,
on_startup: typing.Optional[typing.Sequence[typing.Callable]] = None,
on_shutdown: typing.Optional[typing.Sequence[typing.Callable]] = None,
lifespan: typing.Optional[Lifespan] = None,
lifespan: typing.Optional[
typing.Callable[["Starlette"], typing.AsyncContextManager]
] = None,
) -> None:
# The lifespan context function is a newer style that replaces
# on_startup / on_shutdown handlers. Use one or the other, not both.
Expand Down
15 changes: 14 additions & 1 deletion starlette/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@
from starlette.middleware import Middleware
from starlette.requests import Request
from starlette.responses import PlainTextResponse, RedirectResponse
<<<<<<< HEAD
from starlette.types import ASGIApp, Lifespan, Receive, Scope, Send
=======
from starlette.types import ASGIApp, Receive, Scope, Send
>>>>>>> 44159d0 (Revert "Support lifespan state (#2060)")
from starlette.websockets import WebSocket, WebSocketClose


Expand Down Expand Up @@ -580,7 +584,9 @@ def __init__(
default: typing.Optional[ASGIApp] = None,
on_startup: typing.Optional[typing.Sequence[typing.Callable]] = None,
on_shutdown: typing.Optional[typing.Sequence[typing.Callable]] = None,
lifespan: typing.Optional[Lifespan] = None,
lifespan: typing.Optional[
typing.Callable[[typing.Any], typing.AsyncContextManager]
] = None,
) -> None:
self.routes = [] if routes is None else list(routes)
self.redirect_slashes = redirect_slashes
Expand Down Expand Up @@ -661,6 +667,7 @@ async def lifespan(self, scope: Scope, receive: Receive, send: Send) -> None:
startup and shutdown events.
"""
started = False
<<<<<<< HEAD
app: typing.Any = scope.get("app")
await receive()
try:
Expand All @@ -671,6 +678,12 @@ async def lifespan(self, scope: Scope, receive: Receive, send: Send) -> None:
'The server does not support "state" in the lifespan scope.'
)
scope["state"].update(maybe_state)
=======
app = scope.get("app")
await receive()
try:
async with self.lifespan_context(app):
>>>>>>> 44159d0 (Revert "Support lifespan state (#2060)")
await send({"type": "lifespan.startup.complete"})
started = True
await receive()
Expand Down
9 changes: 1 addition & 8 deletions starlette/testclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,14 +188,11 @@ def __init__(
portal_factory: _PortalFactoryType,
raise_server_exceptions: bool = True,
root_path: str = "",
*,
app_state: typing.Dict[str, typing.Any],
) -> None:
self.app = app
self.raise_server_exceptions = raise_server_exceptions
self.root_path = root_path
self.portal_factory = portal_factory
self.app_state = app_state

def handle_request(self, request: httpx.Request) -> httpx.Response:
scheme = request.url.scheme
Expand Down Expand Up @@ -246,7 +243,6 @@ def handle_request(self, request: httpx.Request) -> httpx.Response:
"client": ["testclient", 50000],
"server": [host, port],
"subprotocols": subprotocols,
"state": self.app_state.copy(),
}
session = WebSocketTestSession(self.app, scope, self.portal_factory)
raise _Upgrade(session)
Expand All @@ -264,7 +260,6 @@ def handle_request(self, request: httpx.Request) -> httpx.Response:
"client": ["testclient", 50000],
"server": [host, port],
"extensions": {"http.response.debug": {}},
"state": self.app_state.copy(),
}

request_complete = False
Expand Down Expand Up @@ -385,13 +380,11 @@ def __init__(
app = typing.cast(ASGI2App, app) # type: ignore[assignment]
asgi_app = _WrapASGI2(app) # type: ignore[arg-type]
self.app = asgi_app
self.app_state: typing.Dict[str, typing.Any] = {}
transport = _TestClientTransport(
self.app,
portal_factory=self._portal_factory,
raise_server_exceptions=raise_server_exceptions,
root_path=root_path,
app_state=self.app_state,
)
if headers is None:
headers = {}
Expand Down Expand Up @@ -756,7 +749,7 @@ def __exit__(self, *args: typing.Any) -> None:
self.exit_stack.close()

async def lifespan(self) -> None:
scope = {"type": "lifespan", "state": self.app_state}
scope = {"type": "lifespan"}
try:
await self.app(scope, self.stream_receive.receive, self.stream_send.send)
finally:
Expand Down
70 changes: 0 additions & 70 deletions tests/test_routing.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import contextlib
import functools
import sys
import typing
Expand Down Expand Up @@ -676,75 +675,6 @@ def run_shutdown():
assert shutdown_complete


def test_lifespan_state_unsupported(test_client_factory):
@contextlib.asynccontextmanager
async def lifespan(app):
yield {"foo": "bar"}

app = Router(
lifespan=lifespan,
routes=[Mount("/", PlainTextResponse("hello, world"))],
)

async def no_state_wrapper(scope, receive, send):
del scope["state"]
await app(scope, receive, send)

with pytest.raises(
RuntimeError, match='The server does not support "state" in the lifespan scope'
):
with test_client_factory(no_state_wrapper):
raise AssertionError("Should not be called") # pragma: no cover


def test_lifespan_state_async_cm(test_client_factory):
startup_complete = False
shutdown_complete = False

class State(TypedDict):
count: int
items: typing.List[int]

async def hello_world(request: Request) -> Response:
# modifications to the state should not leak across requests
assert request.state.count == 0
# modify the state, this should not leak to the lifespan or other requests
request.state.count += 1
# since state.items is a mutable object this modification _will_ leak across
# requests and to the lifespan
request.state.items.append(1)
return PlainTextResponse("hello, world")

@contextlib.asynccontextmanager
async def lifespan(app: Starlette) -> typing.AsyncIterator[State]:
nonlocal startup_complete, shutdown_complete
startup_complete = True
state = State(count=0, items=[])
yield state
shutdown_complete = True
# modifications made to the state from a request do not leak to the lifespan
assert state["count"] == 0
# unless of course the request mutates a mutable object that is referenced
# via state
assert state["items"] == [1, 1]

app = Router(
lifespan=lifespan,
routes=[Route("/", hello_world)],
)

assert not startup_complete
assert not shutdown_complete
with test_client_factory(app) as client:
assert startup_complete
assert not shutdown_complete
client.get("/")
# Calling it a second time to ensure that the state is preserved.
client.get("/")
assert startup_complete
assert shutdown_complete


def test_raise_on_startup(test_client_factory):
def run_startup():
raise RuntimeError()
Expand Down

0 comments on commit 2d34ccc

Please sign in to comment.