Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

linkcheck builder: begin using session-based HTTP requests #11503

Merged
19 changes: 12 additions & 7 deletions sphinx/builders/linkcheck.py
Expand Up @@ -279,12 +279,16 @@ def __init__(self, config: Config,
self.tls_verify = config.tls_verify
self.tls_cacerts = config.tls_cacerts

self._session = requests._Session()

super().__init__(daemon=True)

def run(self) -> None:
while True:
next_check, hyperlink = self.wqueue.get()
if hyperlink is None:
# An empty hyperlink is a signal to shutdown the worker; cleanup resources here
self._session.close()
break

uri, docname, _docpath, lineno = hyperlink
Expand Down Expand Up @@ -346,6 +350,13 @@ def _check(self, docname: str, uri: str, hyperlink: Hyperlink) -> tuple[str, str

return status, info, code

def _retrieval_methods(self,
check_anchors: bool,
anchor: str) -> Iterator[tuple[Callable, dict]]:
if not check_anchors or not anchor:
yield self._session.head, {'allow_redirects': True}
yield self._session.get, {'stream': True}

def _check_uri(self, uri: str, hyperlink: Hyperlink) -> tuple[str, str, int]:
req_url, delimiter, anchor = uri.partition('#')
for rex in self.anchors_ignore if delimiter and anchor else []:
Expand Down Expand Up @@ -377,7 +388,7 @@ def _check_uri(self, uri: str, hyperlink: Hyperlink) -> tuple[str, str, int]:
error_message = None
status_code = -1
response_url = retry_after = ''
for retrieval_method, kwargs in _retrieval_methods(self.check_anchors, anchor):
for retrieval_method, kwargs in self._retrieval_methods(self.check_anchors, anchor):
try:
with retrieval_method(
url=req_url, auth=auth_info,
Expand Down Expand Up @@ -508,12 +519,6 @@ def _get_request_headers(
return {}


def _retrieval_methods(check_anchors: bool, anchor: str) -> Iterator[tuple[Callable, dict]]:
if not check_anchors or not anchor:
yield requests.head, {'allow_redirects': True}
yield requests.get, {'stream': True}


def contains_anchor(response: Response, anchor: str) -> bool:
"""Determine if an anchor is contained within an HTTP response."""

Expand Down
35 changes: 35 additions & 0 deletions sphinx/util/requests.py
Expand Up @@ -77,3 +77,38 @@ def head(url: str,

with ignore_insecure_warning(verify):
return requests.head(url, **kwargs)


class _Session(requests.Session):

def get(self, url: str, **kwargs: Any) -> requests.Response: # type: ignore[override]
"""Sends a GET request like requests.get().

This sets up User-Agent header and TLS verification automatically."""
headers = kwargs.setdefault('headers', {})
headers.setdefault('User-Agent', kwargs.pop('_user_agent', _USER_AGENT))
jayaddison marked this conversation as resolved.
Show resolved Hide resolved
if '_tls_info' in kwargs:
tls_verify, tls_cacerts = kwargs.pop('_tls_info')
verify = bool(kwargs.get('verify', tls_verify))
kwargs.setdefault('verify', verify and _get_tls_cacert(url, tls_cacerts))
else:
verify = kwargs.get('verify', True)

with ignore_insecure_warning(verify):
return super().get(url, **kwargs)

def head(self, url: str, **kwargs: Any) -> requests.Response: # type: ignore[override]
"""Sends a HEAD request like requests.head().

This sets up User-Agent header and TLS verification automatically."""
headers = kwargs.setdefault('headers', {})
headers.setdefault('User-Agent', kwargs.pop('_user_agent', _USER_AGENT))
if '_tls_info' in kwargs:
tls_verify, tls_cacerts = kwargs.pop('_tls_info')
verify = bool(kwargs.get('verify', tls_verify))
kwargs.setdefault('verify', verify and _get_tls_cacert(url, tls_cacerts))
else:
verify = kwargs.get('verify', True)

with ignore_insecure_warning(verify):
return super().head(url, **kwargs)
2 changes: 1 addition & 1 deletion tests/test_build_linkcheck.py
Expand Up @@ -104,7 +104,7 @@ def test_defaults(app):
with http_server(DefaultsHandler):
with ConnectionMeasurement() as m:
app.build()
assert m.connection_count <= 10
assert m.connection_count <= 5

# Text output
assert (app.outdir / 'output.txt').exists()
Expand Down