Skip to content

Commit

Permalink
Used cached request body for downstream ASGI app in BaseHTTPMiddleware
Browse files Browse the repository at this point in the history
  • Loading branch information
adriangb committed Feb 13, 2023
1 parent 5771a78 commit 2048c44
Show file tree
Hide file tree
Showing 3 changed files with 299 additions and 8 deletions.
64 changes: 60 additions & 4 deletions starlette/middleware/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import anyio

from starlette.background import BackgroundTask
from starlette.requests import Request
from starlette.requests import ClientDisconnect, Request
from starlette.responses import ContentStream, Response, StreamingResponse
from starlette.types import ASGIApp, Message, Receive, Scope, Send

Expand All @@ -14,6 +14,61 @@
T = typing.TypeVar("T")


class _CachedRequest(Request):
def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None:
super().__init__(*args, **kwargs)
self._wrapped_rcv_disconnected = False
self._wrapped_rcv_consumed = False

async def wrapped_receive(self) -> Message:
wrapped_rcv_connected = not (
self._wrapped_rcv_disconnected or self._wrapped_rcv_consumed
)
if wrapped_rcv_connected:
if getattr(self, "_body", None) is not None:
# body() was called, we return it even if the client disconnected
self._wrapped_rcv_consumed = True
return {
"type": "http.request",
"body": self._body,
"more_body": False,
}
elif self._stream_consumed:
# stream() was called to completion or client disconnected
self._wrapped_rcv_consumed = True
return {
"type": "http.request",
"body": b"",
"more_body": False,
}
else:
# body() was never called and stream() wasn't consumed
stream = self.stream()
try:
chunk = await stream.__anext__()
self._wrapped_rcv_consumed = self._stream_consumed
return {
"type": "http.request",
"body": chunk,
"more_body": self._stream_consumed,
}
except ClientDisconnect:
self._wrapped_rcv_disconnected = True
return {"type": "http.disconnect"}
# wrapped_rcv is either disconnected or consumed
if self._is_disconnected:
self._wrapped_rcv_disconnected = True
return {"type": "http.disconnect"}
# if we haven't received a disconnect yet we wait for it
msg = await self.receive()
if msg["type"] != "http.disconnect": # pragma: no cover
raise RuntimeError(f"Unexpected message received: {msg['type']}")
# mark ourselves and upstream as disconnected
self._is_disconnected = True
self._wrapped_rcv_disconnected = True
return msg


class BaseHTTPMiddleware:
def __init__(
self, app: ASGIApp, dispatch: typing.Optional[DispatchFunction] = None
Expand All @@ -26,6 +81,8 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
await self.app(scope, receive, send)
return

request = _CachedRequest(scope, receive)
wrapped_receive = request.wrapped_receive
response_sent = anyio.Event()

async def call_next(request: Request) -> Response:
Expand All @@ -44,7 +101,7 @@ async def wrap(func: typing.Callable[[], typing.Awaitable[T]]) -> T:
return result

task_group.start_soon(wrap, response_sent.wait)
message = await wrap(request.receive)
message = await wrap(wrapped_receive)

if response_sent.is_set():
return {"type": "http.disconnect"}
Expand Down Expand Up @@ -104,9 +161,8 @@ async def body_stream() -> typing.AsyncGenerator[bytes, None]:
return response

async with anyio.create_task_group() as task_group:
request = Request(scope, receive=receive)
response = await self.dispatch_func(request, call_next)
await response(scope, receive, send)
await response(scope, wrapped_receive, send)
response_sent.set()

async def dispatch(
Expand Down
4 changes: 3 additions & 1 deletion starlette/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,10 +220,12 @@ async def stream(self) -> typing.AsyncGenerator[bytes, None]:
while True:
message = await self._receive()
if message["type"] == "http.request":
if not message.get("more_body", False):
self._stream_consumed = True
body = message.get("body", b"")
if body:
yield body
if not message.get("more_body", False):
if self._stream_consumed:
break
elif message["type"] == "http.disconnect":
self._is_disconnected = True
Expand Down
239 changes: 236 additions & 3 deletions tests/middleware/test_base.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
import contextvars
from contextlib import AsyncExitStack
from typing import AsyncGenerator, Callable

import anyio
import pytest

from starlette.applications import Starlette
from starlette.background import BackgroundTask
from starlette.middleware import Middleware
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import PlainTextResponse, StreamingResponse
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import Request
from starlette.responses import PlainTextResponse, Response, StreamingResponse
from starlette.routing import Route, WebSocketRoute
from starlette.types import ASGIApp, Receive, Scope, Send
from starlette.testclient import TestClient
from starlette.types import ASGIApp, Message, Receive, Scope, Send


class CustomMiddleware(BaseHTTPMiddleware):
Expand Down Expand Up @@ -413,3 +416,233 @@ async def downstream_app(scope, receive, send):
client = test_client_factory(app)
response = client.get("/does_not_exist")
assert response.text == "Custom"


def test_read_request_stream_in_app_after_middleware_calls_stream(
test_client_factory: Callable[[ASGIApp], TestClient]
) -> None:
async def homepage(request: Request):
expected = [b""]
async for chunk in request.stream():
assert chunk == expected.pop(0)
assert expected == []
return PlainTextResponse("Homepage")

class ConsumingMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
expected = [b"a", b""]
async for chunk in request.stream():
assert chunk == expected.pop(0)
assert expected == []
return await call_next(request)

app = Starlette(
routes=[Route("/", homepage, methods=["POST"])],
middleware=[Middleware(ConsumingMiddleware)],
)

client: TestClient = test_client_factory(app)
response = client.post("/", content=b"a")
assert response.status_code == 200


def test_read_request_stream_in_app_after_middleware_calls_body(
test_client_factory: Callable[[ASGIApp], TestClient]
) -> None:
async def homepage(request: Request):
expected = [b"a", b""]
async for chunk in request.stream():
assert chunk == expected.pop(0)
assert expected == []
return PlainTextResponse("Homepage")

class ConsumingMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
assert await request.body() == b"a"
return await call_next(request)

app = Starlette(
routes=[Route("/", homepage, methods=["POST"])],
middleware=[Middleware(ConsumingMiddleware)],
)

client: TestClient = test_client_factory(app)
response = client.post("/", content=b"a")
assert response.status_code == 200


def test_read_request_body_in_app_after_middleware_calls_stream(
test_client_factory: Callable[[ASGIApp], TestClient]
) -> None:
async def homepage(request: Request):
assert await request.body() == b""
return PlainTextResponse("Homepage")

class ConsumingMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
expected = [b"a", b""]
async for chunk in request.stream():
assert chunk == expected.pop(0)
assert expected == []
return await call_next(request)

app = Starlette(
routes=[Route("/", homepage, methods=["POST"])],
middleware=[Middleware(ConsumingMiddleware)],
)

client: TestClient = test_client_factory(app)
response = client.post("/", content=b"a")
assert response.status_code == 200


def test_read_request_body_in_app_after_middleware_calls_body(
test_client_factory: Callable[[ASGIApp], TestClient]
) -> None:
async def homepage(request: Request):
assert await request.body() == b"a"
return PlainTextResponse("Homepage")

class ConsumingMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
assert await request.body() == b"a"
return await call_next(request)

app = Starlette(
routes=[Route("/", homepage, methods=["POST"])],
middleware=[Middleware(ConsumingMiddleware)],
)

client: TestClient = test_client_factory(app)
response = client.post("/", content=b"a")
assert response.status_code == 200


def test_read_request_stream_in_dispatch_after_app_calls_stream(
test_client_factory: Callable[[ASGIApp], TestClient]
) -> None:
async def homepage(request: Request):
expected = [b"a", b""]
async for chunk in request.stream():
assert chunk == expected.pop(0)
assert expected == []
return PlainTextResponse("Homepage")

class ConsumingMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
resp = await call_next(request)
with pytest.raises(RuntimeError, match="Stream consumed"):
async for _ in request.stream():
raise AssertionError("should not be called") # pragma: no cover
return resp

app = Starlette(
routes=[Route("/", homepage, methods=["POST"])],
middleware=[Middleware(ConsumingMiddleware)],
)

client: TestClient = test_client_factory(app)
response = client.post("/", content=b"a")
assert response.status_code == 200


def test_read_request_stream_in_dispatch_after_app_calls_body(
test_client_factory: Callable[[ASGIApp], TestClient]
) -> None:
async def homepage(request: Request):
assert await request.body() == b"a"
return PlainTextResponse("Homepage")

class ConsumingMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
resp = await call_next(request)
with pytest.raises(RuntimeError, match="Stream consumed"):
async for _ in request.stream():
raise AssertionError("should not be called") # pragma: no cover
return resp

app = Starlette(
routes=[Route("/", homepage, methods=["POST"])],
middleware=[Middleware(ConsumingMiddleware)],
)

client: TestClient = test_client_factory(app)
response = client.post("/", content=b"a")
assert response.status_code == 200


@pytest.mark.anyio
async def test_read_request_disconnected_client() -> None:
"""If we receive a disconnect message when the downstream ASGI
app calls receive() the Request instance passed into the dispatch function
should get marked as disconnected.
The downstream ASGI app should not get a ClientDisconnect raised,
instead if should just receive the disconnect message.
"""

async def endpoint(scope: Scope, receive: Receive, send: Send) -> None:
msg = await receive()
assert msg["type"] == "http.disconnect"
await Response()(scope, receive, send)

class ConsumingMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
response = await call_next(request)
disconnected = await request.is_disconnected()
assert disconnected is True
return response

scope = {"type": "http", "method": "POST", "path": "/"}

async def receive() -> AsyncGenerator[Message, None]:
yield {"type": "http.disconnect"}
raise AssertionError("Should not be called, would hang") # pragma: no cover

async def send(msg: Message):
if msg["type"] == "http.response.start":
assert msg["status"] == 200

app: ASGIApp = ConsumingMiddleware(endpoint)

rcv = receive()

await app(scope, rcv.__anext__, send)

await rcv.aclose()


def test_downstream_middleware_modifies_receive(
test_client_factory: Callable[[ASGIApp], TestClient]
) -> None:
"""If a downstream middleware modifies receive() the final ASGI app
should see the modified version.
"""

async def endpoint(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive)
body = await request.body()
assert body == b"foo foo "
await Response()(scope, receive, send)

class ConsumingMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
body = await request.body()
assert body == b"foo "
return await call_next(request)

def modifying_middleware(app: ASGIApp) -> ASGIApp:
async def wrapped_app(scope: Scope, receive: Receive, send: Send) -> None:
async def wrapped_receive() -> Message:
msg = await receive()
if msg["type"] == "http.request":
msg["body"] = msg["body"] * 2
return msg

await app(scope, wrapped_receive, send)

return wrapped_app

client = test_client_factory(ConsumingMiddleware(modifying_middleware(endpoint)))

resp = client.post("/", content=b"foo ")
assert resp.status_code == 200

0 comments on commit 2048c44

Please sign in to comment.