Skip to content

Commit

Permalink
Add default encoding to client session.
Browse files Browse the repository at this point in the history
  • Loading branch information
john-parton committed Aug 28, 2023
1 parent 44056de commit df4fe75
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 13 deletions.
4 changes: 4 additions & 0 deletions aiohttp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ class ClientSession:
"_read_bufsize",
"_max_line_size",
"_max_field_size",
"_default_encoding",
)

def __init__(
Expand Down Expand Up @@ -221,6 +222,7 @@ def __init__(
read_bufsize: int = 2**16,
max_line_size: int = 8190,
max_field_size: int = 8190,
default_encoding: Union[str, Callable[[ClientResponse], str], None],
) -> None:
if base_url is None or isinstance(base_url, URL):
self._base_url: Optional[URL] = base_url
Expand Down Expand Up @@ -291,6 +293,8 @@ def __init__(
for trace_config in self._trace_configs:
trace_config.freeze()

self._default_encoding = default_encoding

def __init_subclass__(cls: Type["ClientSession"]) -> None:
raise TypeError(
"Inheritance class {} from ClientSession "
Expand Down
43 changes: 30 additions & 13 deletions aiohttp/client_reqrep.py
Original file line number Diff line number Diff line change
Expand Up @@ -1013,18 +1013,32 @@ async def read(self) -> bytes:

return self._body

def _guess_encoding(self) -> str:
warnings.warn(
"Automatic guessing of encodings is scheduled for removal in 5.0",
DeprecationWarning,
stacklevel=2,
)
def _get_default_encoding(self) -> str:
_default_encoding = self._session._default_encoding

if isinstance(_default_encoding, str):
return _default_encoding

if callable(_default_encoding):
return _default_encoding(self)

if _default_encoding is None:
warnings.warn(
"Automatic guessing of encodings is scheduled for removal in 5.0. "
"Pass `default_encoding` to ClientSession instead.",
DeprecationWarning,
stacklevel=2,
)

# TLD parsing should match internal logic for matching charsets in chardetng
tld = self.url.host.split(".")[-1]
# We do not allow utf-8 here because in ordinary operations, utf-8 encoding
# is always tried first
return detect_charset(self._body, tld=tld, allow_utf8=False)
# TLD parsing should match internal logic for matching charsets in chardetng
tld = self.url.host.split(".")[-1]
# We do not allow utf-8 here because in ordinary operations, utf-8 encoding
# is always tried first
return detect_charset(self._body, tld=tld, allow_utf8=False)

raise TypeError(
f"Invalid type for default_encoding: {_default_encoding.__class__.__name__}"
)

def get_encoding(self) -> str:
ctype = self.headers.get(hdrs.CONTENT_TYPE, "").lower()
Expand Down Expand Up @@ -1054,9 +1068,12 @@ async def text(self, encoding: Optional[str] = None, errors: str = "strict") ->
try:
return self._body.decode(self.get_encoding(), errors=errors) # type: ignore[union-attr]
except UnicodeDecodeError:
pass
encoding = self._get_default_encoding()

if encoding is None:
raise

return self._body.decode(self._guess_encoding(), errors=errors) # type: ignore[union-attr]
return self._body.decode(encoding, errors=errors) # type: ignore[union-attr]

async def json(
self,
Expand Down

0 comments on commit df4fe75

Please sign in to comment.