diff --git a/sphinx/builders/linkcheck.py b/sphinx/builders/linkcheck.py index 428669349fa..b3777afdb0b 100644 --- a/sphinx/builders/linkcheck.py +++ b/sphinx/builders/linkcheck.py @@ -279,12 +279,16 @@ def __init__(self, config: Config, self.tls_verify = config.tls_verify self.tls_cacerts = config.tls_cacerts + self._session = requests._Session() + super().__init__(daemon=True) def run(self) -> None: while True: next_check, hyperlink = self.wqueue.get() if hyperlink is None: + # An empty hyperlink is a signal to shutdown the worker; cleanup resources here + self._session.close() break uri, docname, _docpath, lineno = hyperlink @@ -346,6 +350,13 @@ def _check(self, docname: str, uri: str, hyperlink: Hyperlink) -> tuple[str, str return status, info, code + def _retrieval_methods(self, + check_anchors: bool, + anchor: str) -> Iterator[tuple[Callable, dict]]: + if not check_anchors or not anchor: + yield self._session.head, {'allow_redirects': True} + yield self._session.get, {'stream': True} + def _check_uri(self, uri: str, hyperlink: Hyperlink) -> tuple[str, str, int]: req_url, delimiter, anchor = uri.partition('#') for rex in self.anchors_ignore if delimiter and anchor else []: @@ -377,7 +388,7 @@ def _check_uri(self, uri: str, hyperlink: Hyperlink) -> tuple[str, str, int]: error_message = None status_code = -1 response_url = retry_after = '' - for retrieval_method, kwargs in _retrieval_methods(self.check_anchors, anchor): + for retrieval_method, kwargs in self._retrieval_methods(self.check_anchors, anchor): try: with retrieval_method( url=req_url, auth=auth_info, @@ -508,12 +519,6 @@ def _get_request_headers( return {} -def _retrieval_methods(check_anchors: bool, anchor: str) -> Iterator[tuple[Callable, dict]]: - if not check_anchors or not anchor: - yield requests.head, {'allow_redirects': True} - yield requests.get, {'stream': True} - - def contains_anchor(response: Response, anchor: str) -> bool: """Determine if an anchor is contained within an HTTP response.""" diff --git a/sphinx/util/requests.py b/sphinx/util/requests.py index fb89d1237b7..ec3d8d2d5b8 100644 --- a/sphinx/util/requests.py +++ b/sphinx/util/requests.py @@ -3,8 +3,7 @@ from __future__ import annotations import warnings -from contextlib import contextmanager -from typing import Any, Iterator +from typing import Any from urllib.parse import urlsplit import requests @@ -16,15 +15,6 @@ f'Sphinx/{sphinx.__version__}') -@contextmanager -def ignore_insecure_warning(verify: bool) -> Iterator[None]: - with warnings.catch_warnings(): - if not verify: - # ignore InsecureRequestWarning if verify=False - warnings.filterwarnings("ignore", category=InsecureRequestWarning) - yield - - def _get_tls_cacert(url: str, certs: str | dict[str, str] | None) -> str | bool: """Get additional CA cert for a specific URL.""" if not certs: @@ -39,41 +29,45 @@ def _get_tls_cacert(url: str, certs: str | dict[str, str] | None) -> str | bool: return certs.get(hostname, True) -def get(url: str, - _user_agent: str = '', - _tls_info: tuple[bool, str | dict[str, str] | None] = (), # type: ignore[assignment] - **kwargs: Any) -> requests.Response: - """Sends a HEAD request like requests.head(). +def get(url: str, **kwargs: Any) -> requests.Response: + """Sends a GET request like requests.get(). This sets up User-Agent header and TLS verification automatically.""" - headers = kwargs.setdefault('headers', {}) - headers.setdefault('User-Agent', _user_agent or _USER_AGENT) - if _tls_info: - tls_verify, tls_cacerts = _tls_info - verify = bool(kwargs.get('verify', tls_verify)) - kwargs.setdefault('verify', verify and _get_tls_cacert(url, tls_cacerts)) - else: - verify = kwargs.get('verify', True) + with _Session() as session: + return session.get(url, **kwargs) - with ignore_insecure_warning(verify): - return requests.get(url, **kwargs) - -def head(url: str, - _user_agent: str = '', - _tls_info: tuple[bool, str | dict[str, str] | None] = (), # type: ignore[assignment] - **kwargs: Any) -> requests.Response: +def head(url: str, **kwargs: Any) -> requests.Response: """Sends a HEAD request like requests.head(). This sets up User-Agent header and TLS verification automatically.""" - headers = kwargs.setdefault('headers', {}) - headers.setdefault('User-Agent', _user_agent or _USER_AGENT) - if _tls_info: - tls_verify, tls_cacerts = _tls_info - verify = bool(kwargs.get('verify', tls_verify)) - kwargs.setdefault('verify', verify and _get_tls_cacert(url, tls_cacerts)) - else: - verify = kwargs.get('verify', True) + with _Session() as session: + return session.head(url, **kwargs) - with ignore_insecure_warning(verify): - return requests.head(url, **kwargs) + +class _Session(requests.Session): + def request( # type: ignore[override] + self, method: str, url: str, + _user_agent: str = '', + _tls_info: tuple[bool, str | dict[str, str] | None] = (), # type: ignore[assignment] + **kwargs: Any, + ) -> requests.Response: + """Sends a request with an HTTP verb and url. + + This sets up User-Agent header and TLS verification automatically.""" + headers = kwargs.setdefault('headers', {}) + headers.setdefault('User-Agent', _user_agent or _USER_AGENT) + if _tls_info: + tls_verify, tls_cacerts = _tls_info + verify = bool(kwargs.get('verify', tls_verify)) + kwargs.setdefault('verify', verify and _get_tls_cacert(url, tls_cacerts)) + else: + verify = kwargs.get('verify', True) + + if verify: + return super().request(method, url, **kwargs) + + with warnings.catch_warnings(): + # ignore InsecureRequestWarning if verify=False + warnings.filterwarnings("ignore", category=InsecureRequestWarning) + return super().request(method, url, **kwargs) diff --git a/tests/test_build_linkcheck.py b/tests/test_build_linkcheck.py index 0d032f947c8..cd0daaa50ae 100644 --- a/tests/test_build_linkcheck.py +++ b/tests/test_build_linkcheck.py @@ -104,7 +104,7 @@ def test_defaults(app): with http_server(DefaultsHandler): with ConnectionMeasurement() as m: app.build() - assert m.connection_count <= 10 + assert m.connection_count <= 5 # Text output assert (app.outdir / 'output.txt').exists()