Skip to content

Commit

Permalink
Re-organize, add test
Browse files Browse the repository at this point in the history
  • Loading branch information
adriangb committed May 4, 2023
1 parent a3ba02c commit 96d5b49
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 43 deletions.
98 changes: 55 additions & 43 deletions starlette/middleware/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,52 +28,64 @@ def __init__(self, scope: Scope, receive: Receive):
self._wrapped_rcv_consumed = False

async def wrapped_receive(self) -> Message:
wrapped_rcv_connected = not (
self._wrapped_rcv_disconnected or self._wrapped_rcv_consumed
)
if wrapped_rcv_connected:
if getattr(self, "_body", None) is not None:
# body() was called, we return it even if the client disconnected
self._wrapped_rcv_consumed = True
return {
"type": "http.request",
"body": self._body,
"more_body": False,
}
elif self._stream_consumed:
# stream() was called to completion or client disconnected
self._wrapped_rcv_consumed = True
# wrapped_rcv state 1: disconnected
if self._wrapped_rcv_disconnected:
# we've already sent a disconnect to the downstream app
# we don't need to wait to get another one
# (although most ASGI servers will just keep sending it)
return {"type": "http.disconnect"}
# wrapped_rcv state 1: consumed but not yet disconnected
if self._wrapped_rcv_consumed:
# since the downstream app has consumed us all that is left
# is to send it a disconnect
if self._is_disconnected:
# the middleware has already seen the disconnect
# since we know the client is disconnected no need to wait
# for the message
self._wrapped_rcv_disconnected = True
return {"type": "http.disconnect"}
# we don't know yet if the client is disconnected or not
# so we'll wait until we get that message
msg = await self.receive()
if msg["type"] != "http.disconnect": # pragma: no cover
# at this point a disconnect is all that we should be receiving
# if we get something else, things went wrong somewhere
raise RuntimeError(f"Unexpected message received: {msg['type']}")
return msg

# wrapped_rcv state 3: not yet consumed
if getattr(self, "_body", None) is not None:
# body() was called, we return it even if the client disconnected
self._wrapped_rcv_consumed = True
return {
"type": "http.request",
"body": self._body,
"more_body": False,
}
elif self._stream_consumed:
# stream() was called to completion
# return an empty body so that downstream apps don't hang
# waiting for a disconnect
self._wrapped_rcv_consumed = True
return {
"type": "http.request",
"body": b"",
"more_body": False,
}
else:
# body() was never called and stream() wasn't consumed
stream = self.stream()
try:
chunk = await stream.__anext__()
self._wrapped_rcv_consumed = self._stream_consumed
return {
"type": "http.request",
"body": b"",
"more_body": False,
"body": chunk,
"more_body": self._stream_consumed,
}
else:
# body() was never called and stream() wasn't consumed
stream = self.stream()
try:
chunk = await stream.__anext__()
self._wrapped_rcv_consumed = self._stream_consumed
return {
"type": "http.request",
"body": chunk,
"more_body": self._stream_consumed,
}
except ClientDisconnect:
self._wrapped_rcv_disconnected = True
return {"type": "http.disconnect"}
# wrapped_rcv is either disconnected or consumed
if self._is_disconnected:
self._wrapped_rcv_disconnected = True
return {"type": "http.disconnect"}
# if we haven't received a disconnect yet we wait for it
msg = await self.receive()
if msg["type"] != "http.disconnect": # pragma: no cover
raise RuntimeError(f"Unexpected message received: {msg['type']}")
# mark ourselves and upstream as disconnected
self._is_disconnected = True
self._wrapped_rcv_disconnected = True
return msg
except ClientDisconnect:
self._wrapped_rcv_disconnected = True
return {"type": "http.disconnect"}


class BaseHTTPMiddleware:
Expand Down
38 changes: 38 additions & 0 deletions tests/middleware/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -665,6 +665,44 @@ async def send(msg: Message):
await rcv.aclose()


@pytest.mark.anyio
async def test_read_request_disconnected_after_consuming_steam() -> None:
async def endpoint(scope: Scope, receive: Receive, send: Send) -> None:
msg = await receive()
assert msg.pop("more_body", False) is False
assert msg == {"type": "http.request", "body": b"hi"}
msg = await receive()
assert msg == {"type": "http.disconnect"}
await Response()(scope, receive, send)

class ConsumingMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
await request.body()
disconnected = await request.is_disconnected()
assert disconnected is True
response = await call_next(request)
return response

scope = {"type": "http", "method": "POST", "path": "/"}

async def receive() -> AsyncGenerator[Message, None]:
yield {"type": "http.request", "body": b"hi"}
yield {"type": "http.disconnect"}
raise AssertionError("Should not be called, would hang") # pragma: no cover

async def send(msg: Message):
if msg["type"] == "http.response.start":
assert msg["status"] == 200

app: ASGIApp = ConsumingMiddleware(endpoint)

rcv = receive()

await app(scope, rcv.__anext__, send)

await rcv.aclose()


def test_downstream_middleware_modifies_receive(
test_client_factory: Callable[[ASGIApp], TestClient]
) -> None:
Expand Down

0 comments on commit 96d5b49

Please sign in to comment.