Skip to content

Commit

Permalink
Reuse Request's body buffer for call_next in BaseHTTPMiddleware (#1692)
Browse files Browse the repository at this point in the history
  • Loading branch information
adriangb committed Jun 1, 2023
1 parent bdabbf7 commit 554b9e2
Show file tree
Hide file tree
Showing 4 changed files with 573 additions and 13 deletions.
84 changes: 80 additions & 4 deletions starlette/middleware/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import anyio

from starlette.background import BackgroundTask
from starlette.requests import Request
from starlette.requests import ClientDisconnect, Request
from starlette.responses import ContentStream, Response, StreamingResponse
from starlette.types import ASGIApp, Message, Receive, Scope, Send

Expand All @@ -14,6 +14,81 @@
T = typing.TypeVar("T")


class _CachedRequest(Request):
"""
If the user calls Request.body() from their dispatch function
we cache the entire request body in memory and pass that to downstream middlewares,
but if they call Request.stream() then all we do is send an
empty body so that downstream things don't hang forever.
"""

def __init__(self, scope: Scope, receive: Receive):
super().__init__(scope, receive)
self._wrapped_rcv_disconnected = False
self._wrapped_rcv_consumed = False
self._wrapped_rc_stream = self.stream()

async def wrapped_receive(self) -> Message:
# wrapped_rcv state 1: disconnected
if self._wrapped_rcv_disconnected:
# we've already sent a disconnect to the downstream app
# we don't need to wait to get another one
# (although most ASGI servers will just keep sending it)
return {"type": "http.disconnect"}
# wrapped_rcv state 1: consumed but not yet disconnected
if self._wrapped_rcv_consumed:
# since the downstream app has consumed us all that is left
# is to send it a disconnect
if self._is_disconnected:
# the middleware has already seen the disconnect
# since we know the client is disconnected no need to wait
# for the message
self._wrapped_rcv_disconnected = True
return {"type": "http.disconnect"}
# we don't know yet if the client is disconnected or not
# so we'll wait until we get that message
msg = await self.receive()
if msg["type"] != "http.disconnect": # pragma: no cover
# at this point a disconnect is all that we should be receiving
# if we get something else, things went wrong somewhere
raise RuntimeError(f"Unexpected message received: {msg['type']}")
return msg

# wrapped_rcv state 3: not yet consumed
if getattr(self, "_body", None) is not None:
# body() was called, we return it even if the client disconnected
self._wrapped_rcv_consumed = True
return {
"type": "http.request",
"body": self._body,
"more_body": False,
}
elif self._stream_consumed:
# stream() was called to completion
# return an empty body so that downstream apps don't hang
# waiting for a disconnect
self._wrapped_rcv_consumed = True
return {
"type": "http.request",
"body": b"",
"more_body": False,
}
else:
# body() was never called and stream() wasn't consumed
try:
stream = self.stream()
chunk = await stream.__anext__()
self._wrapped_rcv_consumed = self._stream_consumed
return {
"type": "http.request",
"body": chunk,
"more_body": not self._stream_consumed,
}
except ClientDisconnect:
self._wrapped_rcv_disconnected = True
return {"type": "http.disconnect"}


class BaseHTTPMiddleware:
def __init__(
self, app: ASGIApp, dispatch: typing.Optional[DispatchFunction] = None
Expand All @@ -26,6 +101,8 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
await self.app(scope, receive, send)
return

request = _CachedRequest(scope, receive)
wrapped_receive = request.wrapped_receive
response_sent = anyio.Event()

async def call_next(request: Request) -> Response:
Expand All @@ -44,7 +121,7 @@ async def wrap(func: typing.Callable[[], typing.Awaitable[T]]) -> T:
return result

task_group.start_soon(wrap, response_sent.wait)
message = await wrap(request.receive)
message = await wrap(wrapped_receive)

if response_sent.is_set():
return {"type": "http.disconnect"}
Expand Down Expand Up @@ -104,9 +181,8 @@ async def body_stream() -> typing.AsyncGenerator[bytes, None]:
return response

async with anyio.create_task_group() as task_group:
request = Request(scope, receive=receive)
response = await self.dispatch_func(request, call_next)
await response(scope, receive, send)
await response(scope, wrapped_receive, send)
response_sent.set()

async def dispatch(
Expand Down
7 changes: 3 additions & 4 deletions starlette/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,15 +216,14 @@ async def stream(self) -> typing.AsyncGenerator[bytes, None]:
return
if self._stream_consumed:
raise RuntimeError("Stream consumed")
self._stream_consumed = True
while True:
while not self._stream_consumed:
message = await self._receive()
if message["type"] == "http.request":
body = message.get("body", b"")
if not message.get("more_body", False):
self._stream_consumed = True
if body:
yield body
if not message.get("more_body", False):
break
elif message["type"] == "http.disconnect":
self._is_disconnected = True
raise ClientDisconnect()
Expand Down

0 comments on commit 554b9e2

Please sign in to comment.