Skip to content

Commit

Permalink
Used cached request body for downstream ASGI app in BaseHTTPMiddleware
Browse files Browse the repository at this point in the history
  • Loading branch information
adriangb committed Apr 29, 2023
1 parent ac469df commit 6ff01c0
Show file tree
Hide file tree
Showing 3 changed files with 360 additions and 8 deletions.
71 changes: 67 additions & 4 deletions starlette/middleware/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import anyio

from starlette.background import BackgroundTask
from starlette.requests import Request
from starlette.requests import ClientDisconnect, Request
from starlette.responses import ContentStream, Response, StreamingResponse
from starlette.types import ASGIApp, Message, Receive, Scope, Send

Expand All @@ -14,6 +14,68 @@
T = typing.TypeVar("T")


class _CachedRequest(Request):
"""
If the user calls Request.body() from their dispatch function
we cache the entire request body in memory and pass that to downstream middlewares,
but if they call Request.stream() then all we do is send an
empty body so that downstream things don't hang forever.
"""

def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None:
super().__init__(*args, **kwargs)
self._wrapped_rcv_disconnected = False
self._wrapped_rcv_consumed = False

async def wrapped_receive(self) -> Message:
wrapped_rcv_connected = not (
self._wrapped_rcv_disconnected or self._wrapped_rcv_consumed
)
if wrapped_rcv_connected:
if getattr(self, "_body", None) is not None:
# body() was called, we return it even if the client disconnected
self._wrapped_rcv_consumed = True
return {
"type": "http.request",
"body": self._body,
"more_body": False,
}
elif self._stream_consumed:
# stream() was called to completion or client disconnected
self._wrapped_rcv_consumed = True
return {
"type": "http.request",
"body": b"",
"more_body": False,
}
else:
# body() was never called and stream() wasn't consumed
stream = self.stream()
try:
chunk = await stream.__anext__()
self._wrapped_rcv_consumed = self._stream_consumed
return {
"type": "http.request",
"body": chunk,
"more_body": self._stream_consumed,
}
except ClientDisconnect:
self._wrapped_rcv_disconnected = True
return {"type": "http.disconnect"}
# wrapped_rcv is either disconnected or consumed
if self._is_disconnected:
self._wrapped_rcv_disconnected = True
return {"type": "http.disconnect"}
# if we haven't received a disconnect yet we wait for it
msg = await self.receive()
if msg["type"] != "http.disconnect": # pragma: no cover
raise RuntimeError(f"Unexpected message received: {msg['type']}")
# mark ourselves and upstream as disconnected
self._is_disconnected = True
self._wrapped_rcv_disconnected = True
return msg


class BaseHTTPMiddleware:
def __init__(
self, app: ASGIApp, dispatch: typing.Optional[DispatchFunction] = None
Expand All @@ -26,6 +88,8 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
await self.app(scope, receive, send)
return

request = _CachedRequest(scope, receive)
wrapped_receive = request.wrapped_receive
response_sent = anyio.Event()

async def call_next(request: Request) -> Response:
Expand All @@ -44,7 +108,7 @@ async def wrap(func: typing.Callable[[], typing.Awaitable[T]]) -> T:
return result

task_group.start_soon(wrap, response_sent.wait)
message = await wrap(request.receive)
message = await wrap(wrapped_receive)

if response_sent.is_set():
return {"type": "http.disconnect"}
Expand Down Expand Up @@ -104,9 +168,8 @@ async def body_stream() -> typing.AsyncGenerator[bytes, None]:
return response

async with anyio.create_task_group() as task_group:
request = Request(scope, receive=receive)
response = await self.dispatch_func(request, call_next)
await response(scope, receive, send)
await response(scope, wrapped_receive, send)
response_sent.set()

async def dispatch(
Expand Down
4 changes: 3 additions & 1 deletion starlette/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,10 +220,12 @@ async def stream(self) -> typing.AsyncGenerator[bytes, None]:
while True:
message = await self._receive()
if message["type"] == "http.request":
if not message.get("more_body", False):
self._stream_consumed = True
body = message.get("body", b"")
if body:
yield body
if not message.get("more_body", False):
if self._stream_consumed:
break
elif message["type"] == "http.disconnect":
self._is_disconnected = True
Expand Down

0 comments on commit 6ff01c0

Please sign in to comment.