-
Notifications
You must be signed in to change notification settings - Fork 296
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: Add optional non blocking refresh for sync auth code #1368
Changes from 27 commits
a07551f
20f8ec4
377e19c
31e2f33
aab77bb
1848c2a
8d741d5
20941ed
3ca4b97
c7b11f8
0417312
64e1c7c
b94d6e1
88c00e1
e2623f4
b5281bd
5d30915
4b146c4
5476a8f
77ff534
a19fe0d
0e35c67
f0b218e
a6b1815
df7220c
89ceeeb
41042c2
f557447
4cecc4a
be2e2da
d941394
f5e77e5
be5f27e
fcf1f8a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
# Copyright 2023 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import logging | ||
import threading | ||
|
||
import google.auth.exceptions as e | ||
|
||
_LOGGER = logging.getLogger(__name__) | ||
|
||
|
||
class RefreshThreadManager: | ||
""" | ||
Organizes exactly one background job that refresh a token. | ||
""" | ||
|
||
def __init__(self): | ||
"""Initializes the manager.""" | ||
|
||
self._worker = None | ||
self._lock = threading.Lock() # protects access to worker threads. | ||
|
||
def start_refresh(self, cred, request): | ||
"""Starts a refresh thread for the given credentials. | ||
The credentials are refreshed using the request parameter. | ||
request and cred MUST not be None | ||
|
||
Args: | ||
cred: A credentials object. | ||
request: A request object. | ||
""" | ||
if cred is None or request is None: | ||
raise e.InvalidValue( | ||
"Unable to start refresh. cred and request must be valid and instantiated objects." | ||
) | ||
|
||
clundin25 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
with self._lock: | ||
if self._worker is not None and self._worker._error_info is not None: | ||
raise e.RefreshError( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this has a kind of weird 'every 2nd call fails' pattern that doesn't seem quite right. I would suggest we try to background refresh once, log and record the error, then do foreground refreshes from then on until we get a new token. You might return There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sounds good! |
||
f"Could not start a background refresh. The background refresh previously failed with {self._worker._error_info}." | ||
) from self._worker._error_info | ||
|
||
if self._worker is None or not self._worker.is_alive(): # pragma: NO COVER | ||
self._worker = RefreshThread(cred=cred, request=request) | ||
self._worker.start() | ||
|
||
def get_error(self): | ||
""" | ||
Returns the error that occurred in the refresh thread. Clears the error once called. | ||
|
||
Returns: | ||
Optional[exceptions.Exception] | ||
""" | ||
if not self._worker: | ||
return None | ||
err, self._worker._error_info = self._worker._error_info, None | ||
return err | ||
|
||
|
||
class RefreshThread(threading.Thread): | ||
""" | ||
Thread that refreshes credentials. | ||
""" | ||
|
||
def __init__(self, cred, request, **kwargs): | ||
"""Initializes the thread. | ||
|
||
Args: | ||
cred: A Credential object to refresh. | ||
request: A Request object used to perform a credential refresh. | ||
**kwargs: Additional keyword arguments. | ||
""" | ||
|
||
super().__init__(**kwargs) | ||
self._cred = cred | ||
self._request = request | ||
self._error_info = None | ||
|
||
def run(self): | ||
""" | ||
Perform the credential refresh. | ||
""" | ||
try: | ||
self._cred.refresh(self._request) | ||
except Exception as err: # pragma: NO COVER | ||
_LOGGER.error(f"Background refresh failed due to: {err}") | ||
self._error_info = err |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,11 +16,13 @@ | |
"""Interfaces for credentials.""" | ||
|
||
import abc | ||
from enum import Enum | ||
import os | ||
|
||
from google.auth import _helpers, environment_vars | ||
from google.auth import exceptions | ||
from google.auth import metrics | ||
from google.auth._refresh_worker import RefreshThreadManager | ||
|
||
|
||
class Credentials(metaclass=abc.ABCMeta): | ||
|
@@ -59,17 +61,22 @@ def __init__(self): | |
"""Optional[str]: The universe domain value, default is googleapis.com | ||
""" | ||
|
||
self._use_non_blocking_refresh = False | ||
self._refresh_worker = RefreshThreadManager() | ||
|
||
@property | ||
def expired(self): | ||
"""Checks if the credentials are expired. | ||
|
||
Note that credentials can be invalid but not expired because | ||
Credentials with :attr:`expiry` set to None is considered to never | ||
expire. | ||
|
||
.. deprecated:: v2.24.0 | ||
Prefer checking :attr:`token_state` instead. | ||
""" | ||
if not self.expiry: | ||
return False | ||
|
||
# Remove some threshold from expiry to err on the side of reporting | ||
# expiration early so that we avoid the 401-refresh-retry loop. | ||
skewed_expiry = self.expiry - _helpers.REFRESH_THRESHOLD | ||
|
@@ -81,9 +88,34 @@ def valid(self): | |
|
||
This is True if the credentials have a :attr:`token` and the token | ||
is not :attr:`expired`. | ||
|
||
.. deprecated:: v2.24.0 | ||
Prefer checking :attr:`token_state` instead. | ||
""" | ||
return self.token is not None and not self.expired | ||
|
||
@property | ||
def token_state(self): | ||
clundin25 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
""" | ||
See `:obj:`TokenState` | ||
""" | ||
if self.token is None: | ||
return TokenState.INVALID | ||
|
||
# Credentials that can't expire are always treated as fresh. | ||
if self.expiry is None: | ||
return TokenState.FRESH | ||
|
||
expired = _helpers.utcnow() >= self.expiry | ||
if expired: | ||
return TokenState.INVALID | ||
|
||
is_stale = _helpers.utcnow() >= (self.expiry - _helpers.REFRESH_THRESHOLD) | ||
if is_stale: | ||
return TokenState.STALE | ||
|
||
return TokenState.FRESH | ||
|
||
@property | ||
def quota_project_id(self): | ||
"""Project to use for quota and billing purposes.""" | ||
|
@@ -154,6 +186,17 @@ def apply(self, headers, token=None): | |
if self.quota_project_id: | ||
headers["x-goog-user-project"] = self.quota_project_id | ||
|
||
def _blocking_refresh(self, request): | ||
if not self.valid: | ||
self.refresh(request) | ||
|
||
def _non_blocking_refresh(self, request): | ||
if self.token_state == TokenState.STALE: | ||
self._refresh_worker.start_refresh(self, request) | ||
|
||
if self.token_state == TokenState.INVALID: | ||
self.refresh(request) | ||
|
||
def before_request(self, request, method, url, headers): | ||
"""Performs credential-specific before request logic. | ||
|
||
|
@@ -171,11 +214,26 @@ def before_request(self, request, method, url, headers): | |
# pylint: disable=unused-argument | ||
# (Subclasses may use these arguments to ascertain information about | ||
# the http request.) | ||
if not self.valid: | ||
self.refresh(request) | ||
if self._use_non_blocking_refresh: | ||
self._non_blocking_refresh(request) | ||
else: | ||
self._blocking_refresh(request) | ||
|
||
metrics.add_metric_header(headers, self._metric_header_for_usage()) | ||
self.apply(headers) | ||
|
||
def with_non_blocking_refresh(self): | ||
self._use_non_blocking_refresh = True | ||
|
||
def get_background_refresh_error(self): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think I would prefer to remove this from the public API:
|
||
""" | ||
Returns the error in from a failed background refresh. Once called, the error will be flushed. | ||
|
||
Returns: | ||
Optional[exceptions.Exception] | ||
""" | ||
return self._refresh_worker.get_error() | ||
|
||
|
||
class CredentialsWithQuotaProject(Credentials): | ||
"""Abstract base for credentials supporting ``with_quota_project`` factory""" | ||
|
@@ -422,3 +480,16 @@ def signer(self): | |
# pylint: disable=missing-raises-doc | ||
# (pylint doesn't recognize that this is abstract) | ||
raise NotImplementedError("Signer must be implemented.") | ||
|
||
|
||
class TokenState(Enum): | ||
clundin25 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
""" | ||
Tracks the state of a token. | ||
FRESH: The token is valid. It is not expired or close to expired, or the token has no expiry. To make it mutually exclusive to STALE. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. sorry about this: the "To make it mutually exclusive to STALE" part wasn't meant to be part of the comment. |
||
STALE: The token is close to expired, and should be refreshed. The token can be used normally. | ||
INVALID: The token is expired or invalid. The token cannot be used for a normal operation. | ||
""" | ||
|
||
FRESH = 1 | ||
STALE = 2 | ||
INVALID = 3 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is
request
immutable or can we create a copy here?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you expand on why copy it? I don't think it is necessary, as the request should not be re-used in other places.
I think copying the request would be a breaking change.
I'd have to do some work to understand the implication of deep copying a request object.