Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

linkcheck builder: begin using session-based HTTP requests #11503

Merged
19 changes: 12 additions & 7 deletions sphinx/builders/linkcheck.py
Expand Up @@ -279,12 +279,16 @@ def __init__(self, config: Config,
self.tls_verify = config.tls_verify
self.tls_cacerts = config.tls_cacerts

self._session = requests._Session()

super().__init__(daemon=True)

def run(self) -> None:
while True:
next_check, hyperlink = self.wqueue.get()
if hyperlink is None:
# An empty hyperlink is the signal to shut down the worker; clean up resources here
self._session.close()
break

uri, docname, _docpath, lineno = hyperlink
Expand Down Expand Up @@ -346,6 +350,13 @@ def _check(self, docname: str, uri: str, hyperlink: Hyperlink) -> tuple[str, str

return status, info, code

def _retrieval_methods(self,
                       check_anchors: bool,
                       anchor: str) -> Iterator[tuple[Callable, dict]]:
    """Yield ``(request_callable, keyword_arguments)`` pairs to try, in order.

    The callables are bound methods of this worker's session.  A HEAD
    request (following redirects) is offered first unless anchor checking
    is enabled *and* an anchor is present; a streaming GET request is
    always offered last.
    """
    anchor_check_needed = bool(check_anchors and anchor)
    if not anchor_check_needed:
        yield self._session.head, {'allow_redirects': True}
    yield self._session.get, {'stream': True}

def _check_uri(self, uri: str, hyperlink: Hyperlink) -> tuple[str, str, int]:
req_url, delimiter, anchor = uri.partition('#')
for rex in self.anchors_ignore if delimiter and anchor else []:
Expand Down Expand Up @@ -377,7 +388,7 @@ def _check_uri(self, uri: str, hyperlink: Hyperlink) -> tuple[str, str, int]:
error_message = None
status_code = -1
response_url = retry_after = ''
for retrieval_method, kwargs in _retrieval_methods(self.check_anchors, anchor):
for retrieval_method, kwargs in self._retrieval_methods(self.check_anchors, anchor):
try:
with retrieval_method(
url=req_url, auth=auth_info,
Expand Down Expand Up @@ -508,12 +519,6 @@ def _get_request_headers(
return {}


def _retrieval_methods(check_anchors: bool, anchor: str) -> Iterator[tuple[Callable, dict]]:
    """Yield ``(request_function, keyword_arguments)`` pairs to try, in order.

    ``requests.head`` (following redirects) is offered first unless anchor
    checking is enabled *and* an anchor is present; a streaming
    ``requests.get`` is always offered last.
    """
    anchor_check_needed = bool(check_anchors and anchor)
    if not anchor_check_needed:
        yield requests.head, {'allow_redirects': True}
    yield requests.get, {'stream': True}


def contains_anchor(response: Response, anchor: str) -> bool:
"""Determine if an anchor is contained within an HTTP response."""

Expand Down
76 changes: 35 additions & 41 deletions sphinx/util/requests.py
Expand Up @@ -3,8 +3,7 @@
from __future__ import annotations

import warnings
from contextlib import contextmanager
from typing import Any, Iterator
from typing import Any
from urllib.parse import urlsplit

import requests
Expand All @@ -16,15 +15,6 @@
f'Sphinx/{sphinx.__version__}')


@contextmanager
def ignore_insecure_warning(verify: bool) -> Iterator[None]:
    """Context manager that hides ``InsecureRequestWarning`` when *verify* is false.

    The body always runs inside ``warnings.catch_warnings()``; the "ignore"
    filter is only installed when verification is disabled.
    """
    suppress = not verify
    with warnings.catch_warnings():
        if suppress:
            # verify=False: silence InsecureRequestWarning inside this context
            warnings.filterwarnings("ignore", category=InsecureRequestWarning)
        yield


def _get_tls_cacert(url: str, certs: str | dict[str, str] | None) -> str | bool:
"""Get additional CA cert for a specific URL."""
if not certs:
Expand All @@ -39,41 +29,45 @@ def _get_tls_cacert(url: str, certs: str | dict[str, str] | None) -> str | bool:
return certs.get(hostname, True)


def get(url: str,
_user_agent: str = '',
_tls_info: tuple[bool, str | dict[str, str] | None] = (), # type: ignore[assignment]
**kwargs: Any) -> requests.Response:
"""Sends a HEAD request like requests.head().
def get(url: str, **kwargs: Any) -> requests.Response:
"""Sends a GET request like requests.get().

This sets up User-Agent header and TLS verification automatically."""
headers = kwargs.setdefault('headers', {})
headers.setdefault('User-Agent', _user_agent or _USER_AGENT)
if _tls_info:
tls_verify, tls_cacerts = _tls_info
verify = bool(kwargs.get('verify', tls_verify))
kwargs.setdefault('verify', verify and _get_tls_cacert(url, tls_cacerts))
else:
verify = kwargs.get('verify', True)
with _Session() as session:
return session.get(url, **kwargs)

with ignore_insecure_warning(verify):
return requests.get(url, **kwargs)


def head(url: str,
_user_agent: str = '',
_tls_info: tuple[bool, str | dict[str, str] | None] = (), # type: ignore[assignment]
**kwargs: Any) -> requests.Response:
def head(url: str, **kwargs: Any) -> requests.Response:
"""Sends a HEAD request like requests.head().

This sets up User-Agent header and TLS verification automatically."""
headers = kwargs.setdefault('headers', {})
headers.setdefault('User-Agent', _user_agent or _USER_AGENT)
if _tls_info:
tls_verify, tls_cacerts = _tls_info
verify = bool(kwargs.get('verify', tls_verify))
kwargs.setdefault('verify', verify and _get_tls_cacert(url, tls_cacerts))
else:
verify = kwargs.get('verify', True)
with _Session() as session:
return session.head(url, **kwargs)

with ignore_insecure_warning(verify):
return requests.head(url, **kwargs)

class _Session(requests.Session):
    """A ``requests.Session`` whose requests get a User-Agent and TLS settings."""

    def request(  # type: ignore[override]
        self, method: str, url: str,
        _user_agent: str = '',
        _tls_info: tuple[bool, str | dict[str, str] | None] = (),  # type: ignore[assignment]
        **kwargs: Any,
    ) -> requests.Response:
        """Send an HTTP request for *url* using the verb *method*.

        A ``User-Agent`` header is filled in unless the caller already set
        one: ``_user_agent`` if non-empty, otherwise the module default.
        When ``_tls_info`` (a ``(tls_verify, tls_cacerts)`` pair) is given
        and the caller passed no explicit ``verify``, ``verify`` is derived
        from it.  ``InsecureRequestWarning`` is filtered out while a request
        is made with verification disabled.
        """
        kwargs.setdefault('headers', {}).setdefault(
            'User-Agent', _user_agent or _USER_AGENT)

        if not _tls_info:
            verify = kwargs.get('verify', True)
        else:
            tls_verify, tls_cacerts = _tls_info
            verify = bool(kwargs.get('verify', tls_verify))
            # An explicit ``verify`` from the caller is left untouched.
            kwargs.setdefault('verify', verify and _get_tls_cacert(url, tls_cacerts))

        if not verify:
            with warnings.catch_warnings():
                # ignore InsecureRequestWarning if verify=False
                warnings.filterwarnings("ignore", category=InsecureRequestWarning)
                return super().request(method, url, **kwargs)

        return super().request(method, url, **kwargs)
2 changes: 1 addition & 1 deletion tests/test_build_linkcheck.py
Expand Up @@ -104,7 +104,7 @@ def test_defaults(app):
with http_server(DefaultsHandler):
with ConnectionMeasurement() as m:
app.build()
assert m.connection_count <= 10
assert m.connection_count <= 5

# Text output
assert (app.outdir / 'output.txt').exists()
Expand Down