Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow lax response parsing on Py parser (#7663) #7664

Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES/7663.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Updated the Python parser to comply with the latest HTTP specs and allow lax response parsing -- by :user:`Dreamsorcerer`
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ aiohttp/_find_header.c: $(call to-hash,aiohttp/hdrs.py ./tools/gen.py)

# _find_headers generator creates _headers.pyi as well
aiohttp/%.c: aiohttp/%.pyx $(call to-hash,$(CYS)) aiohttp/_find_header.c
cython -3 -o $@ $< -I aiohttp
cython -3 -o $@ $< -I aiohttp -Werror

vendor/llhttp/node_modules: vendor/llhttp/package.json
cd vendor/llhttp; npm install
Expand Down
76 changes: 59 additions & 17 deletions aiohttp/http_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
import collections
import re
import string
import sys
import zlib
from contextlib import suppress
from enum import IntEnum
from typing import (
Any,
ClassVar,
Generic,
List,
NamedTuple,
Expand All @@ -26,7 +28,7 @@

from . import hdrs
from .base_protocol import BaseProtocol
from .helpers import NO_EXTENSIONS, BaseTimerContext
from .helpers import DEBUG, NO_EXTENSIONS, BaseTimerContext
from .http_exceptions import (
BadHttpMessage,
BadStatusLine,
Expand All @@ -41,6 +43,11 @@
from .streams import EMPTY_PAYLOAD, StreamReader
from .typedefs import Final, RawHeaders

if sys.version_info >= (3, 8):
from typing import Literal
else:
from typing_extensions import Literal

try:
import brotli

Expand All @@ -58,6 +65,8 @@
"RawResponseMessage",
)

# Line separators the parsers accept: CRLF, or a bare LF (the latter is
# used for lax response parsing).
_SEP = Literal[b"\r\n", b"\n"]

ASCIISET: Final[Set[str]] = set(string.printable)

# See https://www.rfc-editor.org/rfc/rfc9110.html#name-overview
Expand All @@ -70,6 +79,7 @@
METHRE: Final[Pattern[str]] = re.compile(r"[!#$%&'*+\-.^_`|~0-9A-Za-z]+")
# The dot is escaped so that only a literal "." may separate the major and
# minor version (an unescaped "." would also accept e.g. "HTTP/1x1").
# re.ASCII restricts \d to 0-9, so non-ASCII decimal digits are rejected.
VERSRE: Final[Pattern[str]] = re.compile(r"HTTP/(\d)\.(\d)", re.ASCII)
HDRRE: Final[Pattern[bytes]] = re.compile(rb"[\x00-\x1F\x7F()<>@,;:\[\]={} \t\"\\]")
# One or more hexadecimal digits; used to validate chunk sizes.
HEXDIGIT: Final[Pattern[bytes]] = re.compile(rb"[0-9a-fA-F]+")


class RawRequestMessage(NamedTuple):
Expand Down Expand Up @@ -173,7 +183,8 @@ def parse_headers(
# consume continuation lines
continuation = line and line[0] in (32, 9) # (' ', '\t')

# Deprecated: https://www.rfc-editor.org/rfc/rfc9112.html#name-obsolete-line-folding
# Deprecated:
# https://www.rfc-editor.org/rfc/rfc9112.html#name-obsolete-line-folding
if continuation:
bvalue_lst = [bvalue]
while continuation:
Expand Down Expand Up @@ -223,6 +234,8 @@ def parse_headers(


class HttpParser(abc.ABC, Generic[_MsgT]):
lax: ClassVar[bool] = False

def __init__(
self,
protocol: Optional[BaseProtocol] = None,
Expand Down Expand Up @@ -285,7 +298,7 @@ def feed_eof(self) -> Optional[_MsgT]:
def feed_data(
self,
data: bytes,
SEP: bytes = b"\r\n",
SEP: _SEP = b"\r\n",
EMPTY: bytes = b"",
CONTENT_LENGTH: istr = hdrs.CONTENT_LENGTH,
METH_CONNECT: str = hdrs.METH_CONNECT,
Expand All @@ -309,13 +322,16 @@ def feed_data(
pos = data.find(SEP, start_pos)
# consume \r\n
if pos == start_pos and not self._lines:
start_pos = pos + 2
start_pos = pos + len(SEP)
continue

if pos >= start_pos:
# line found
self._lines.append(data[start_pos:pos])
start_pos = pos + 2
line = data[start_pos:pos]
if SEP == b"\n": # For lax response parsing
line = line.rstrip(b"\r")
self._lines.append(line)
start_pos = pos + len(SEP)

# \r\n\r\n found
if self._lines[-1] == EMPTY:
Expand All @@ -332,7 +348,7 @@ def get_content_length() -> Optional[int]:

# Shouldn't allow +/- or other number formats.
# https://www.rfc-editor.org/rfc/rfc9110#section-8.6-2
if not length_hdr.strip(" \t").isdigit():
if not length_hdr.strip(" \t").isdecimal():
raise InvalidHeader(CONTENT_LENGTH)

return int(length_hdr)
Expand Down Expand Up @@ -369,6 +385,7 @@ def get_content_length() -> Optional[int]:
readall=self.readall,
response_with_body=self.response_with_body,
auto_decompress=self._auto_decompress,
lax=self.lax,
)
if not payload_parser.done:
self._payload_parser = payload_parser
Expand All @@ -387,6 +404,7 @@ def get_content_length() -> Optional[int]:
compression=msg.compression,
readall=True,
auto_decompress=self._auto_decompress,
lax=self.lax,
)
else:
if (
Expand All @@ -410,6 +428,7 @@ def get_content_length() -> Optional[int]:
readall=True,
response_with_body=self.response_with_body,
auto_decompress=self._auto_decompress,
lax=self.lax,
)
if not payload_parser.done:
self._payload_parser = payload_parser
Expand All @@ -432,7 +451,7 @@ def get_content_length() -> Optional[int]:
assert not self._lines
assert self._payload_parser is not None
try:
eof, data = self._payload_parser.feed_data(data[start_pos:])
eof, data = self._payload_parser.feed_data(data[start_pos:], SEP)
except BaseException as exc:
if self.payload_exception is not None:
self._payload_parser.payload.set_exception(
Expand Down Expand Up @@ -627,6 +646,20 @@ class HttpResponseParser(HttpParser[RawResponseMessage]):
Returns RawResponseMessage.
"""

# Lax mode should only be enabled on response parser.
lax = not DEBUG

def feed_data(
    self,
    data: bytes,
    SEP: Optional[_SEP] = None,
    *args: Any,
    **kwargs: Any,
) -> Tuple[List[Tuple[RawResponseMessage, StreamReader]], bool, bytes]:
    """Feed response bytes to the base parser.

    When the caller does not supply ``SEP``, a separator is chosen here:
    strict CRLF if ``DEBUG`` is set, otherwise a bare LF (lax parsing).
    Any remaining arguments are forwarded to the base implementation
    unchanged.
    """
    separator = SEP
    if separator is None:
        # Outside debug mode, split on LF alone to tolerate lax servers.
        separator = b"\n" if not DEBUG else b"\r\n"
    return super().feed_data(data, separator, *args, **kwargs)

def parse_message(self, lines: List[bytes]) -> RawResponseMessage:
line = lines[0].decode("utf-8", "surrogateescape")
try:
Expand All @@ -651,7 +684,7 @@ def parse_message(self, lines: List[bytes]) -> RawResponseMessage:
version_o = HttpVersion(int(match.group(1)), int(match.group(2)))

# The status code is a three-digit number
if len(status) != 3 or not status.isdigit():
if len(status) != 3 or not status.isdecimal():
raise BadStatusLine(line)
status_i = int(status)

Expand Down Expand Up @@ -693,13 +726,15 @@ def __init__(
readall: bool = False,
response_with_body: bool = True,
auto_decompress: bool = True,
lax: bool = False,
) -> None:
self._length = 0
self._type = ParseState.PARSE_NONE
self._chunk = ChunkState.PARSE_CHUNKED_SIZE
self._chunk_size = 0
self._chunk_tail = b""
self._auto_decompress = auto_decompress
self._lax = lax
self.done = False

# payload decompression wrapper
Expand Down Expand Up @@ -751,7 +786,7 @@ def feed_eof(self) -> None:
)

def feed_data(
self, chunk: bytes, SEP: bytes = b"\r\n", CHUNK_EXT: bytes = b";"
self, chunk: bytes, SEP: _SEP = b"\r\n", CHUNK_EXT: bytes = b";"
) -> Tuple[bool, bytes]:
# Read specified amount of bytes
if self._type == ParseState.PARSE_LENGTH:
Expand Down Expand Up @@ -788,17 +823,22 @@ def feed_data(
else:
size_b = chunk[:pos]

if not size_b.isdigit():
if self._lax: # Allow whitespace in lax mode.
size_b = size_b.strip()

if not re.fullmatch(HEXDIGIT, size_b):
exc = TransferEncodingError(
chunk[:pos].decode("ascii", "surrogateescape")
)
self.payload.set_exception(exc)
raise exc
size = int(bytes(size_b), 16)

chunk = chunk[pos + 2 :]
chunk = chunk[pos + len(SEP) :]
if size == 0: # eof marker
self._chunk = ChunkState.PARSE_MAYBE_TRAILERS
if self._lax and chunk.startswith(b"\r"):
chunk = chunk[1:]
else:
self._chunk = ChunkState.PARSE_CHUNKED_CHUNK
self._chunk_size = size
Expand All @@ -820,13 +860,15 @@ def feed_data(
self._chunk_size = 0
self.payload.feed_data(chunk[:required], required)
chunk = chunk[required:]
if self._lax and chunk.startswith(b"\r"):
chunk = chunk[1:]
self._chunk = ChunkState.PARSE_CHUNKED_CHUNK_EOF
self.payload.end_http_chunk_receiving()

# toss the CRLF at the end of the chunk
if self._chunk == ChunkState.PARSE_CHUNKED_CHUNK_EOF:
if chunk[:2] == SEP:
chunk = chunk[2:]
if chunk[: len(SEP)] == SEP:
chunk = chunk[len(SEP) :]
self._chunk = ChunkState.PARSE_CHUNKED_SIZE
else:
self._chunk_tail = chunk
Expand All @@ -836,11 +878,11 @@ def feed_data(
# we should get another \r\n otherwise
# trailers need to be skipped until \r\n\r\n
if self._chunk == ChunkState.PARSE_MAYBE_TRAILERS:
head = chunk[:2]
head = chunk[: len(SEP)]
if head == SEP:
# end of stream
self.payload.feed_eof()
return True, chunk[2:]
return True, chunk[len(SEP) :]
# Both CR and LF, or only LF may not be received yet. It is
# expected that CRLF or LF will be shown at the very first
# byte next time, otherwise trailers should come. The last
Expand All @@ -858,7 +900,7 @@ def feed_data(
if self._chunk == ChunkState.PARSE_TRAILERS:
pos = chunk.find(SEP)
if pos >= 0:
chunk = chunk[pos + 2 :]
chunk = chunk[pos + len(SEP) :]
self._chunk = ChunkState.PARSE_MAYBE_TRAILERS
else:
self._chunk_tail = chunk
Expand Down