
Reuse Request's body buffer for call_next in BaseHTTPMiddleware #1692

Merged: 10 commits, Jun 1, 2023
71 changes: 67 additions & 4 deletions starlette/middleware/base.py
@@ -3,7 +3,7 @@
import anyio

from starlette.background import BackgroundTask
-from starlette.requests import Request
+from starlette.requests import ClientDisconnect, Request
from starlette.responses import ContentStream, Response, StreamingResponse
from starlette.types import ASGIApp, Message, Receive, Scope, Send

@@ -14,6 +14,68 @@
T = typing.TypeVar("T")


class _CachedRequest(Request):
    """
    If the user calls Request.body() from their dispatch function
    we cache the entire request body in memory and pass that to downstream middlewares,
    but if they call Request.stream() then all we do is send an
    empty body so that downstream things don't hang forever.
    """

    def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None:
        super().__init__(*args, **kwargs)
        self._wrapped_rcv_disconnected = False
        self._wrapped_rcv_consumed = False

    async def wrapped_receive(self) -> Message:
        wrapped_rcv_connected = not (
            self._wrapped_rcv_disconnected or self._wrapped_rcv_consumed
        )
        if wrapped_rcv_connected:
            if getattr(self, "_body", None) is not None:
                # body() was called, we return it even if the client disconnected
                self._wrapped_rcv_consumed = True
                return {
                    "type": "http.request",
                    "body": self._body,
                    "more_body": False,
                }
            elif self._stream_consumed:
                # stream() was called to completion or client disconnected
                self._wrapped_rcv_consumed = True
                return {
                    "type": "http.request",
                    "body": b"",
                    "more_body": False,
                }
            else:
                # body() was never called and stream() wasn't consumed
                stream = self.stream()
                try:
                    chunk = await stream.__anext__()
                    self._wrapped_rcv_consumed = self._stream_consumed
                    return {
                        "type": "http.request",
                        "body": chunk,
                        "more_body": not self._stream_consumed,
                    }
                except ClientDisconnect:
                    self._wrapped_rcv_disconnected = True
                    return {"type": "http.disconnect"}

        # wrapped_rcv is either disconnected or consumed
        if self._is_disconnected:
            self._wrapped_rcv_disconnected = True
            return {"type": "http.disconnect"}
        # if we haven't received a disconnect yet we wait for it
        msg = await self.receive()
        if msg["type"] != "http.disconnect":  # pragma: no cover
            raise RuntimeError(f"Unexpected message received: {msg['type']}")
        # mark ourselves and upstream as disconnected
        self._is_disconnected = True
        self._wrapped_rcv_disconnected = True
        return msg


class BaseHTTPMiddleware:
    def __init__(
        self, app: ASGIApp, dispatch: typing.Optional[DispatchFunction] = None
@@ -26,6 +88,8 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
            await self.app(scope, receive, send)
            return

        request = _CachedRequest(scope, receive)
Sponsor Member:
This doesn't need to be here, does it? Can it be in the same place it was instantiated before, or am I missing something?

Member Author:
The next line references this. It could perhaps be moved to where it was before, but at the very least it would need a new variable name like outer_request to differentiate it from the request: Request on line 95. It makes more sense to just move it up here; there is no harm in that.
        wrapped_receive = request.wrapped_receive
        response_sent = anyio.Event()

        async def call_next(request: Request) -> Response:
@@ -44,7 +108,7 @@ async def wrap(func: typing.Callable[[], typing.Awaitable[T]]) -> T:
                        return result

                    task_group.start_soon(wrap, response_sent.wait)
-                    message = await wrap(request.receive)
+                    message = await wrap(wrapped_receive)

                if response_sent.is_set():
                    return {"type": "http.disconnect"}
@@ -104,9 +168,8 @@ async def body_stream() -> typing.AsyncGenerator[bytes, None]:
            return response

        async with anyio.create_task_group() as task_group:
-            request = Request(scope, receive=receive)
            response = await self.dispatch_func(request, call_next)
-            await response(scope, receive, send)
+            await response(scope, wrapped_receive, send)
            response_sent.set()

    async def dispatch(
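For context, here is a minimal sketch of the behavior this change enables: the dispatch function consumes the request body via Request.body(), and the downstream endpoint can still read the same body, because call_next now replays the cached buffer through wrapped_receive. The middleware and endpoint names here are illustrative, not part of the PR.

```python
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import PlainTextResponse
from starlette.routing import Route


async def echo(request: Request) -> PlainTextResponse:
    # Reads the body a second time; before this PR the endpoint would hang
    # here, because the middleware had already consumed the body.
    return PlainTextResponse(await request.body())


class BodyInspectingMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request, call_next):
        body = await request.body()  # consumes (and caches) the request body
        print(f"dispatch saw {len(body)} bytes")
        return await call_next(request)


app = Starlette(
    routes=[Route("/", echo, methods=["POST"])],
    middleware=[Middleware(BodyInspectingMiddleware)],
)
```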
2 changes: 1 addition & 1 deletion starlette/requests.py
@@ -223,7 +223,7 @@ async def stream(self) -> typing.AsyncGenerator[bytes, None]:
                body = message.get("body", b"")
                if body:
                    yield body
-                if not message.get("more_body", False):
+                if self._stream_consumed:
Sponsor Member:
Is this change needed? The _CachedRequest doesn't change the value of self._stream_consumed. 🤔

Sponsor Member:
> The test test_read_request_stream_in_dispatch_after_app_calls_body fails without this logic.

Hmm... why doesn't more_body matter? Like, leaving BaseHTTPMiddleware aside, why doesn't more_body matter for deciding when to exit?

Hmmm... if we receive 2 chunks of body, how does this work? It doesn't look like we have a test that covers a standalone Request with multiple chunks. 🤔

Sponsor Member:
Or am I missing something?

Member Author:
We really should have tests for Request as a standalone thing, since it is a standalone thing in the public API and I've been encouraging people to use it, e.g. in ASGI middleware.

Sponsor Member:
We do; we just don't cover what I mentioned.

Member Author:
I'll add it. If you already prototyped it out in your head or on paper, please comment it here and save me a few min haha.

Sponsor Member:
I don't recall how to do it from the TestClient's POV, but I thought about sending a stream with 2 chunks. Maybe you can use httpx directly if you can't do it with the TestClient.

I guess that would be enough to break this logic here, since the value of stream_consumed will not change.

Member Author:
Ok yes, you were right, I did have a bug; good catch. I still need to modify Request a bit, and I added a couple of tests to explain why. TLDR: we were marking the stream as consumed as soon as you called stream(), but in reality you can call stream(), get one message, and then call stream() again before it is consumed. Let me know if it's clear now.

                    break
            elif message["type"] == "http.disconnect":
                self._is_disconnected = True
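A sketch of the two-chunk scenario discussed in the thread above, driving a standalone Request with a hand-rolled receive callable rather than the TestClient or httpx. It assumes the post-fix semantics the author describes (the stream is only marked consumed once fully read, so a second call to stream() resumes where the first left off); it is illustrative and is not the test added in the PR.

```python
import anyio

from starlette.requests import Request


async def main() -> None:
    # Two body chunks, as the reviewer suggested.
    messages = [
        {"type": "http.request", "body": b"chunk1", "more_body": True},
        {"type": "http.request", "body": b"chunk2", "more_body": False},
    ]

    async def receive():
        return messages.pop(0)

    scope = {"type": "http", "method": "POST", "path": "/", "headers": []}
    request = Request(scope, receive)

    # Read one message, then abandon the generator before it is consumed.
    stream = request.stream()
    first = await stream.__anext__()
    assert first == b"chunk1"

    # With the fix, the stream is not yet marked consumed, so a new call to
    # stream() picks up the remaining chunk instead of raising RuntimeError.
    rest = [chunk async for chunk in request.stream()]
    assert rest[0] == b"chunk2"


anyio.run(main)
```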