Skip to content

Commit

Permalink
Merge pull request #3 from encode/request-auto-headers
Browse files Browse the repository at this point in the history
Request auto headers
  • Loading branch information
tomchristie committed Apr 16, 2019
2 parents 60befed + ee6f42a commit fae53e2
Show file tree
Hide file tree
Showing 5 changed files with 141 additions and 14 deletions.
2 changes: 1 addition & 1 deletion httpcore/compat.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
try:
import brotli
except ImportError:
brotli = None
brotli = None # pragma: nocover
8 changes: 1 addition & 7 deletions httpcore/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,7 @@ async def open(
async def send(self, request: Request, stream: bool = False) -> Response:
method = request.method.encode()
target = request.url.target
host_header = (b"host", request.url.netloc.encode("ascii"))
if request.is_streaming:
content_length = (b"transfer-encoding", b"chunked")
else:
content_length = (b"content-length", str(len(request.body)).encode())

headers = [host_header, content_length] + request.headers
headers = request.headers

#  Start sending the request.
event = h11.Request(method=method, target=target, headers=headers)
Expand Down
44 changes: 40 additions & 4 deletions httpcore/datastructures.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
import typing
from urllib.parse import urlsplit

from .decoders import SUPPORTED_DECODERS, Decoder, IdentityDecoder, MultiDecoder
from .decoders import (
ACCEPT_ENCODING,
SUPPORTED_DECODERS,
Decoder,
IdentityDecoder,
MultiDecoder,
)
from .exceptions import ResponseClosed, StreamConsumed


Expand Down Expand Up @@ -59,20 +65,50 @@ class Request:
def __init__(
self,
method: str,
url: URL,
url: typing.Union[str, URL],
*,
headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (),
body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
):
self.method = method
self.url = url
self.url = URL(url) if isinstance(url, str) else url
self.headers = list(headers)
if isinstance(body, bytes):
self.is_streaming = False
self.body = body
else:
self.is_streaming = True
self.body_aiter = body
self.headers = self._auto_headers() + self.headers

def _auto_headers(self) -> typing.List[typing.Tuple[bytes, bytes]]:
has_host = False
has_content_length = False
has_accept_encoding = False

for header, value in self.headers:
header = header.strip().lower()
if header == b"host":
has_host = True
elif header in (b"content-length", b"transfer-encoding"):
has_content_length = True
elif header == b"accept-encoding":
has_accept_encoding = True

headers = [] # type: typing.List[typing.Tuple[bytes, bytes]]

if not has_host:
headers.append((b"host", self.url.netloc.encode("ascii")))
if not has_content_length:
if self.is_streaming:
headers.append((b"transfer-encoding", b"chunked"))
elif self.body:
content_length = str(len(self.body)).encode()
headers.append((b"content-length", content_length))
if not has_accept_encoding:
headers.append((b"accept-encoding", ACCEPT_ENCODING))

return headers

async def stream(self) -> typing.AsyncIterator[bytes]:
assert self.is_streaming
Expand Down Expand Up @@ -131,7 +167,7 @@ async def read(self) -> bytes:
async def stream(self) -> typing.AsyncIterator[bytes]:
"""
A byte-iterator over the decoded response content.
This will allow us to handle gzip, deflate, and brotli encoded responses.
This allows us to handle gzip, deflate, and brotli encoded responses.
"""
if hasattr(self, "body"):
yield self.body
Expand Down
9 changes: 7 additions & 2 deletions httpcore/decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,12 +107,17 @@ def flush(self) -> bytes:


SUPPORTED_DECODERS = {
b"gzip": GZipDecoder,
b"deflate": DeflateDecoder,
b"identity": IdentityDecoder,
b"deflate": DeflateDecoder,
b"gzip": GZipDecoder,
b"br": BrotliDecoder,
}


if brotli is None:
SUPPORTED_DECODERS.pop(b"br") # pragma: nocover


ACCEPT_ENCODING = b", ".join(
[key for key in SUPPORTED_DECODERS.keys() if key != b"identity"]
)
92 changes: 92 additions & 0 deletions tests/test_requests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import pytest

import httpcore


def test_host_header():
request = httpcore.Request("GET", "http://example.org")
assert request.headers == [
(b"host", b"example.org"),
(b"accept-encoding", b"deflate, gzip, br"),
]


def test_content_length_header():
request = httpcore.Request("POST", "http://example.org", body=b"test 123")
assert request.headers == [
(b"host", b"example.org"),
(b"content-length", b"8"),
(b"accept-encoding", b"deflate, gzip, br"),
]


def test_transfer_encoding_header():
async def streaming_body(data):
yield data # pragma: nocover

body = streaming_body(b"test 123")

request = httpcore.Request("POST", "http://example.org", body=body)
assert request.headers == [
(b"host", b"example.org"),
(b"transfer-encoding", b"chunked"),
(b"accept-encoding", b"deflate, gzip, br"),
]


def test_override_host_header():
headers = [(b"host", b"1.2.3.4:80")]

request = httpcore.Request("GET", "http://example.org", headers=headers)
assert request.headers == [
(b"accept-encoding", b"deflate, gzip, br"),
(b"host", b"1.2.3.4:80"),
]


def test_override_accept_encoding_header():
headers = [(b"accept-encoding", b"identity")]

request = httpcore.Request("GET", "http://example.org", headers=headers)
assert request.headers == [
(b"host", b"example.org"),
(b"accept-encoding", b"identity"),
]


def test_override_content_length_header():
async def streaming_body(data):
yield data # pragma: nocover

body = streaming_body(b"test 123")
headers = [(b"content-length", b"8")]

request = httpcore.Request("POST", "http://example.org", body=body, headers=headers)
assert request.headers == [
(b"host", b"example.org"),
(b"accept-encoding", b"deflate, gzip, br"),
(b"content-length", b"8"),
]


def test_url():
request = httpcore.Request("GET", "http://example.org")
assert request.url.scheme == "http"
assert request.url.port == 80
assert request.url.target == "/"

request = httpcore.Request("GET", "https://example.org/abc?foo=bar")
assert request.url.scheme == "https"
assert request.url.port == 443
assert request.url.target == "/abc?foo=bar"


def test_invalid_urls():
with pytest.raises(ValueError):
httpcore.Request("GET", "example.org")

with pytest.raises(ValueError):
httpcore.Request("GET", "invalid://example.org")

with pytest.raises(ValueError):
httpcore.Request("GET", "http:///foo")

0 comments on commit fae53e2

Please sign in to comment.