From f6d0ae35abf4083c5fcec6d10da605514aa96831 Mon Sep 17 00:00:00 2001 From: Seth Michael Larson Date: Fri, 17 Feb 2023 10:17:19 -0600 Subject: [PATCH] Allow str/bytes subclasses to be used as header parts --- requests/_internal_utils.py | 6 ++++-- requests/utils.py | 30 +++++++++++++++++++----------- tests/test_requests.py | 25 +++++++++++++++++++++++++ 3 files changed, 48 insertions(+), 13 deletions(-) diff --git a/requests/_internal_utils.py b/requests/_internal_utils.py index 7dc9bc5336..f2cf635e29 100644 --- a/requests/_internal_utils.py +++ b/requests/_internal_utils.py @@ -14,9 +14,11 @@ _VALID_HEADER_VALUE_RE_BYTE = re.compile(rb"^\S[^\r\n]*$|^$") _VALID_HEADER_VALUE_RE_STR = re.compile(r"^\S[^\r\n]*$|^$") +_HEADER_VALIDATORS_STR = (_VALID_HEADER_NAME_RE_STR, _VALID_HEADER_VALUE_RE_STR) +_HEADER_VALIDATORS_BYTE = (_VALID_HEADER_NAME_RE_BYTE, _VALID_HEADER_VALUE_RE_BYTE) HEADER_VALIDATORS = { - bytes: (_VALID_HEADER_NAME_RE_BYTE, _VALID_HEADER_VALUE_RE_BYTE), - str: (_VALID_HEADER_NAME_RE_STR, _VALID_HEADER_VALUE_RE_STR), + bytes: _HEADER_VALIDATORS_BYTE, + str: _HEADER_VALIDATORS_STR, } diff --git a/requests/utils.py b/requests/utils.py index ad5358381a..a367417f8e 100644 --- a/requests/utils.py +++ b/requests/utils.py @@ -25,7 +25,12 @@ from .__version__ import __version__ # to_native_string is unused here, but imported here for backwards compatibility -from ._internal_utils import HEADER_VALIDATORS, to_native_string # noqa: F401 +from ._internal_utils import ( # noqa: F401 + _HEADER_VALIDATORS_BYTE, + _HEADER_VALIDATORS_STR, + HEADER_VALIDATORS, + to_native_string, +) from .compat import ( Mapping, basestring, @@ -1031,20 +1036,23 @@ def check_header_validity(header): :param header: tuple, in the format (name, value). """ name, value = header + _validate_header_part(header, name, 0) + _validate_header_part(header, value, 1) - for part in header: - if type(part) not in HEADER_VALIDATORS: - raise InvalidHeader( - f"Header part ({part!r}) from {{{name!r}: {value!r}}} must be " - f"of type str or bytes, not {type(part)}" - ) - - _validate_header_part(name, "name", HEADER_VALIDATORS[type(name)][0]) - _validate_header_part(value, "value", HEADER_VALIDATORS[type(value)][1]) +def _validate_header_part(header, header_part, header_validator_index): + if isinstance(header_part, str): + validator = _HEADER_VALIDATORS_STR[header_validator_index] + elif isinstance(header_part, bytes): + validator = _HEADER_VALIDATORS_BYTE[header_validator_index] + else: + raise InvalidHeader( + f"Header part ({header_part!r}) from {header} " + f"must be of type str or bytes, not {type(header_part)}" + ) -def _validate_header_part(header_part, header_kind, validator): if not validator.match(header_part): + header_kind = "name" if header_validator_index == 0 else "value" raise InvalidHeader( f"Invalid leading whitespace, reserved character(s), or return" f"character(s) in header {header_kind}: {header_part!r}" diff --git a/tests/test_requests.py b/tests/test_requests.py index 2a5c9b1102..b1c8dd4534 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -1752,6 +1752,31 @@ def test_header_no_leading_space(self, httpbin, invalid_header): with pytest.raises(InvalidHeader): requests.get(httpbin("get"), headers=invalid_header) + def test_header_with_subclass_types(self, httpbin): + """If the subclasses does not behave *exactly* like + the base bytes/str classes, this is not supported. + This test is for backwards compatibility. + """ + + class MyString(str): + pass + + class MyBytes(bytes): + pass + + r_str = requests.get(httpbin("get"), headers={MyString("x-custom"): "myheader"}) + assert r_str.request.headers["x-custom"] == "myheader" + + r_bytes = requests.get( + httpbin("get"), headers={MyBytes(b"x-custom"): b"myheader"} + ) + assert r_bytes.request.headers["x-custom"] == b"myheader" + + r_mixed = requests.get( + httpbin("get"), headers={MyString("x-custom"): MyBytes(b"myheader")} + ) + assert r_mixed.request.headers["x-custom"] == b"myheader" + @pytest.mark.parametrize("files", ("foo", b"foo", bytearray(b"foo"))) def test_can_send_objects_with_files(self, httpbin, files): data = {"a": "this is a string"}