Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reuse Request's body buffer for call_next in BaseHTTPMiddleware #1692

Merged
merged 10 commits into from
Jun 1, 2023
84 changes: 80 additions & 4 deletions starlette/middleware/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import anyio

from starlette.background import BackgroundTask
from starlette.requests import Request
from starlette.requests import ClientDisconnect, Request
from starlette.responses import ContentStream, Response, StreamingResponse
from starlette.types import ASGIApp, Message, Receive, Scope, Send

Expand All @@ -14,6 +14,81 @@
T = typing.TypeVar("T")


class _CachedRequest(Request):
    """
    If the user calls Request.body() from their dispatch function
    we cache the entire request body in memory and pass that to downstream middlewares,
    but if they call Request.stream() then all we do is send an
    empty body so that downstream things don't hang forever.
    """

    def __init__(self, scope: Scope, receive: Receive):
        super().__init__(scope, receive)
        # True once we've emitted an http.disconnect to the downstream app.
        self._wrapped_rcv_disconnected = False
        # True once the downstream app has received the complete body from us.
        self._wrapped_rcv_consumed = False
        # A single shared stream generator: repeated wrapped_receive() calls
        # must continue from where the previous chunk left off rather than
        # creating a fresh generator each time.
        self._wrapped_rcv_stream = self.stream()

    async def wrapped_receive(self) -> Message:
        # wrapped_rcv state 1: disconnected
        if self._wrapped_rcv_disconnected:
            # we've already sent a disconnect to the downstream app
            # we don't need to wait to get another one
            # (although most ASGI servers will just keep sending it)
            return {"type": "http.disconnect"}
        # wrapped_rcv state 2: consumed but not yet disconnected
        if self._wrapped_rcv_consumed:
            # since the downstream app has consumed us all that is left
            # is to send it a disconnect
            if self._is_disconnected:
                # the middleware has already seen the disconnect
                # since we know the client is disconnected no need to wait
                # for the message
                self._wrapped_rcv_disconnected = True
                return {"type": "http.disconnect"}
            # we don't know yet if the client is disconnected or not
            # so we'll wait until we get that message
            msg = await self.receive()
            if msg["type"] != "http.disconnect":  # pragma: no cover
                # at this point a disconnect is all that we should be receiving
                # if we get something else, things went wrong somewhere
                raise RuntimeError(f"Unexpected message received: {msg['type']}")
            return msg

        # wrapped_rcv state 3: not yet consumed
        if getattr(self, "_body", None) is not None:
            # body() was called, we return it even if the client disconnected
            self._wrapped_rcv_consumed = True
            return {
                "type": "http.request",
                "body": self._body,
                "more_body": False,
            }
        elif self._stream_consumed:
            # stream() was called to completion
            # return an empty body so that downstream apps don't hang
            # waiting for a disconnect
            self._wrapped_rcv_consumed = True
            return {
                "type": "http.request",
                "body": b"",
                "more_body": False,
            }
        else:
            # body() was never called and stream() wasn't consumed
            try:
                # pull the next chunk from the shared generator so state
                # carries over across successive wrapped_receive() calls
                chunk = await self._wrapped_rcv_stream.__anext__()
                self._wrapped_rcv_consumed = self._stream_consumed
                return {
                    "type": "http.request",
                    "body": chunk,
                    "more_body": not self._stream_consumed,
                }
            except ClientDisconnect:
                self._wrapped_rcv_disconnected = True
                return {"type": "http.disconnect"}


class BaseHTTPMiddleware:
def __init__(
self, app: ASGIApp, dispatch: typing.Optional[DispatchFunction] = None
Expand All @@ -26,6 +101,8 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
await self.app(scope, receive, send)
return

request = _CachedRequest(scope, receive)
Copy link
Sponsor Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't need to be here, does it? Can it be in the same place it was instantiated before, or am I missing something?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The next line references this. It, maybe, can be moved to where it was before but at the very least it will need a new variable name like `outer_request` to differentiate it from the `request: Request` on line 95. It makes more sense to just move it up here, there is no harm in that.

wrapped_receive = request.wrapped_receive
response_sent = anyio.Event()

async def call_next(request: Request) -> Response:
Expand All @@ -44,7 +121,7 @@ async def wrap(func: typing.Callable[[], typing.Awaitable[T]]) -> T:
return result

task_group.start_soon(wrap, response_sent.wait)
message = await wrap(request.receive)
message = await wrap(wrapped_receive)

if response_sent.is_set():
return {"type": "http.disconnect"}
Expand Down Expand Up @@ -104,9 +181,8 @@ async def body_stream() -> typing.AsyncGenerator[bytes, None]:
return response

async with anyio.create_task_group() as task_group:
request = Request(scope, receive=receive)
response = await self.dispatch_func(request, call_next)
await response(scope, receive, send)
await response(scope, wrapped_receive, send)
response_sent.set()

async def dispatch(
Expand Down
7 changes: 3 additions & 4 deletions starlette/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,15 +216,14 @@ async def stream(self) -> typing.AsyncGenerator[bytes, None]:
return
if self._stream_consumed:
raise RuntimeError("Stream consumed")
self._stream_consumed = True
while True:
while not self._stream_consumed:
message = await self._receive()
if message["type"] == "http.request":
body = message.get("body", b"")
if not message.get("more_body", False):
self._stream_consumed = True
adriangb marked this conversation as resolved.
Show resolved Hide resolved
if body:
yield body
if not message.get("more_body", False):
break
elif message["type"] == "http.disconnect":
self._is_disconnected = True
raise ClientDisconnect()
Expand Down