feat: Add optional non blocking refresh for sync auth code #1368

Merged
merged 34 commits into from
Dec 18, 2023
Changes from 26 commits
a07551f
feat: Add optional non blocking refresh for sync auth code
clundin25 Aug 11, 2023
20f8ec4
Rename refresh_window to is_stale.
clundin25 Oct 31, 2023
377e19c
Add API to get background refresh errors.
clundin25 Nov 2, 2023
31e2f33
Apply error handling feedback.
clundin25 Nov 9, 2023
aab77bb
chore: Refresh system test creds.
clundin25 Nov 9, 2023
1848c2a
Fix linter.
clundin25 Nov 9, 2023
8d741d5
No cover issue.
clundin25 Nov 9, 2023
20941ed
Fix coverage test.
clundin25 Nov 10, 2023
3ca4b97
Fix linter.
clundin25 Nov 10, 2023
c7b11f8
chore: Refresh system test creds.
clundin25 Nov 10, 2023
0417312
Fixed linter error.
clundin25 Nov 10, 2023
64e1c7c
Merge branch 'main' into pre-emptive-refresh
BigTailWolf Nov 13, 2023
b94d6e1
Fix coverage test.
clundin25 Nov 13, 2023
88c00e1
Add no cover.
clundin25 Nov 13, 2023
e2623f4
Merge remote-tracking branch 'origin/pre-emptive-refresh' into pre-em…
clundin25 Nov 13, 2023
b5281bd
chore: Refresh system test creds.
clundin25 Nov 13, 2023
5d30915
One more cover pragma and run linter.
clundin25 Nov 13, 2023
4b146c4
Remove unnecessary code change.
clundin25 Nov 15, 2023
5476a8f
Merge remote-tracking branch 'upstream/main' into pre-emptive-refresh
clundin25 Nov 15, 2023
77ff534
chore: Refresh system test creds.
clundin25 Nov 15, 2023
a19fe0d
PR feedback.
clundin25 Dec 14, 2023
0e35c67
Merge remote-tracking branch 'upstream/main' into pre-emptive-refresh
clundin25 Dec 14, 2023
f0b218e
chore: Refresh system test creds.
clundin25 Dec 14, 2023
a6b1815
Fix lint
clundin25 Dec 14, 2023
df7220c
Use field to store error in RefreshThread.
clundin25 Dec 14, 2023
89ceeeb
Remove some left over constants. Combine locks.
clundin25 Dec 14, 2023
41042c2
PR feedback.
clundin25 Dec 14, 2023
f557447
Remove exception deadlock
clundin25 Dec 14, 2023
4cecc4a
Don't try to refresh stale tokens that have an active error.
clundin25 Dec 14, 2023
be2e2da
Fix cover check and logic error on `has_error()`
clundin25 Dec 14, 2023
d941394
PR feedback.
clundin25 Dec 15, 2023
f5e77e5
Deep copy request object when starting background refresh.
clundin25 Dec 15, 2023
be5f27e
PR feedback.
clundin25 Dec 18, 2023
fcf1f8a
Merge remote-tracking branch 'upstream/main' into pre-emptive-refresh
clundin25 Dec 18, 2023
104 changes: 104 additions & 0 deletions google/auth/_refresh_worker.py
@@ -0,0 +1,104 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import threading

import google.auth.exceptions as e

_LOGGER = logging.getLogger(__name__)


class RefreshThreadManager:
    """
    Organizes exactly one background job that refreshes a token.
    """

    def __init__(self):
        """Initializes the manager."""

        self._worker = None
        self._lock = threading.Lock()

    def _need_worker(self):
        return self._worker is None or not self._worker.is_alive()

    def _spawn_worker(self, cred, request):
        self._worker = RefreshThread(cred=cred, request=request)
        self._worker.start()

    def start_refresh(self, cred, request):


is request immutable or can we create a copy here?

Contributor Author

Can you expand on why copy it? I don't think it is necessary, as the request should not be re-used in other places.

I think copying the request would be a breaking change.

I'd have to do some work to understand the implication of deep copying a request object.

        """Starts a refresh thread for the given credentials.
        The credentials are refreshed using the request parameter.
        request and cred MUST NOT be None.

        Args:
            cred: A credentials object.
            request: A request object.
        """
        if cred is None or request is None:
            raise e.InvalidValue(
                "Unable to start refresh. cred and request must be valid and instantiated objects."
            )

        with self._lock:
            if self._worker is not None and self._worker._error_info is not None:
                raise e.RefreshError(


this has a kind of weird 'every 2nd call fails' pattern that doesn't seem quite right.

I would suggest we try to background refresh once, log and record the error, then do foreground refreshes from then on until we get a new token.

You might return false from start_refresh to cause the caller to call refresh on their own thread. The caller would be responsible for calling "clear error" any time they have a good token from refresh().

Contributor Author

Sounds good!

                    f"Could not start a background refresh. The background refresh previously failed with {self._worker._error_info}."
                ) from self._worker._error_info

            if self._need_worker():  # pragma: NO COVER
                self._spawn_worker(cred, request)

    def get_error(self):
        """
        Returns the error that occurred in the refresh thread. Clears the error once called.

        Returns:
            Optional[exceptions.Exception]
        """
        if not self._worker:
            return None
        err, self._worker._error_info = self._worker._error_info, None
        return err


class RefreshThread(threading.Thread):
    """
    Thread that refreshes credentials.
    """

    def __init__(self, cred, request, **kwargs):
        """Initializes the thread.

        Args:
            cred: A Credential object to refresh.
            request: A Request object used to perform a credential refresh.
            **kwargs: Additional keyword arguments.
        """

        super().__init__(**kwargs)
        self._cred = cred
        self._request = request
        self._error_info = None

    def run(self):
        """
        Perform the credential refresh.
        """
        try:
            self._cred.refresh(self._request)
        except Exception as err:  # pragma: NO COVER
            _LOGGER.error(f"Background refresh failed due to: {err}")
            self._error_info = err
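
The review thread on `start_refresh` suggests returning a flag instead of raising, so that the caller falls back to a foreground refresh after a background failure and clears the error once it holds a good token. A minimal standard-library sketch of that idea is below. `SketchWorker`, `SketchManager`, `FlakyCred`, `clear_error`, and `join` are hypothetical names used only for illustration; this is not necessarily the code that was merged.

```python
import threading


class SketchWorker(threading.Thread):
    """Hypothetical stand-in for RefreshThread: one refresh, error recorded."""

    def __init__(self, cred, request):
        super().__init__(daemon=True)
        self._cred = cred
        self._request = request
        self.error = None

    def run(self):
        try:
            self._cred.refresh(self._request)
        except Exception as err:  # record the failure instead of propagating it
            self.error = err


class SketchManager:
    """Hypothetical manager that allows at most one live worker."""

    def __init__(self):
        self._worker = None
        self._lock = threading.Lock()

    def start_refresh(self, cred, request):
        """Returns False if an earlier background refresh failed, else True."""
        with self._lock:
            if self._worker is not None and self._worker.error is not None:
                return False  # caller should refresh on its own thread
            if self._worker is None or not self._worker.is_alive():
                self._worker = SketchWorker(cred, request)
                self._worker.start()
        return True

    def clear_error(self):
        """Called once a foreground refresh has produced a good token."""
        with self._lock:
            if self._worker is not None:
                self._worker.error = None

    def join(self):
        """Helper for demos and tests: wait for the current worker, if any."""
        if self._worker is not None:
            self._worker.join()


class FlakyCred:
    """Hypothetical credential whose first refresh fails."""

    def __init__(self):
        self.calls = 0
        self.token = None

    def refresh(self, request):
        self.calls += 1
        if self.calls == 1:
            raise RuntimeError("transient failure")
        self.token = "token-%d" % self.calls
```

With this shape, a failed background attempt is reported exactly once through the return value, and background refreshes resume only after the caller calls `clear_error()`, which avoids the "every 2nd call fails" pattern raised in the review.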
77 changes: 74 additions & 3 deletions google/auth/credentials.py
@@ -16,11 +16,13 @@
"""Interfaces for credentials."""

import abc
from enum import Enum
import os

from google.auth import _helpers, environment_vars
from google.auth import exceptions
from google.auth import metrics
from google.auth._refresh_worker import RefreshThreadManager


class Credentials(metaclass=abc.ABCMeta):
@@ -59,17 +61,22 @@ def __init__(self):
"""Optional[str]: The universe domain value, default is googleapis.com
"""

self._use_non_blocking_refresh = False
self._refresh_worker = RefreshThreadManager()

@property
def expired(self):
"""Checks if the credentials are expired.

Note that credentials can be invalid but not expired, because
credentials with :attr:`expiry` set to None are considered to never
expire.

.. deprecated:: v2.24.0
Prefer checking :attr:`token_state` instead.
"""
if not self.expiry:
return False

# Remove some threshold from expiry to err on the side of reporting
# expiration early so that we avoid the 401-refresh-retry loop.
skewed_expiry = self.expiry - _helpers.REFRESH_THRESHOLD
@@ -81,9 +88,34 @@ def valid(self):

This is True if the credentials have a :attr:`token` and the token
is not :attr:`expired`.

.. deprecated:: v2.24.0
Prefer checking :attr:`token_state` instead.
"""
return self.token is not None and not self.expired

    @property
    def token_state(self):
        """
        See :obj:`TokenState`
        """
        if self.token is None:
            return TokenState.INVALID

        # Credentials that can't expire are always treated as fresh.
        if self.expiry is None:
            return TokenState.FRESH

        expired = _helpers.utcnow() >= self.expiry
        if expired:
            return TokenState.INVALID

        is_stale = _helpers.utcnow() >= (self.expiry - _helpers.REFRESH_THRESHOLD)
        if is_stale:
            return TokenState.STALE

        return TokenState.FRESH

@property
def quota_project_id(self):
"""Project to use for quota and billing purposes."""
@@ -154,6 +186,17 @@ def apply(self, headers, token=None):
if self.quota_project_id:
headers["x-goog-user-project"] = self.quota_project_id

    def _blocking_refresh(self, request):
        if not self.valid:
            self.refresh(request)

    def _non_blocking_refresh(self, request):
        if self.token_state == TokenState.STALE:
            self._refresh_worker.start_refresh(self, request)

        if self.token_state == TokenState.INVALID:
            self.refresh(request)

def before_request(self, request, method, url, headers):
"""Performs credential-specific before request logic.

@@ -171,11 +214,26 @@ def before_request(self, request, method, url, headers):
# pylint: disable=unused-argument
# (Subclasses may use these arguments to ascertain information about
# the http request.)
-        if not self.valid:
-            self.refresh(request)
+        if self._use_non_blocking_refresh:
+            self._non_blocking_refresh(request)
+        else:
+            self._blocking_refresh(request)

metrics.add_metric_header(headers, self._metric_header_for_usage())
self.apply(headers)

    def with_non_blocking_refresh(self):
        self._use_non_blocking_refresh = True

    def get_background_refresh_error(self):


I think I would prefer to remove this from the public API:

  • it is not relevant without the feature flag
  • it isn't clear when you'd want to call it
  • it's hard to implement safely

        """
        Returns the error from a failed background refresh. Once called, the error will be flushed.

        Returns:
            Optional[exceptions.Exception]
        """
        return self._refresh_worker.get_error()


class CredentialsWithQuotaProject(Credentials):
"""Abstract base for credentials supporting ``with_quota_project`` factory"""
@@ -422,3 +480,16 @@ def signer(self):
# pylint: disable=missing-raises-doc
# (pylint doesn't recognize that this is abstract)
raise NotImplementedError("Signer must be implemented.")


class TokenState(Enum):
    """
    Tracks the state of a token.
    FRESH: The token is not expired and can be used normally.
    STALE: The token is close to expiring and should be refreshed. The token can still be used normally.
    INVALID: The token is expired or invalid. The token cannot be used for a normal operation.
    """

    FRESH = 1
    STALE = 2
    INVALID = 3
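
To make the three states and the refresh dispatch easier to follow, here is a standalone sketch that mirrors the branch order of `token_state` and the choice made in `before_request`. `SketchTokenState`, `classify`, `choose_action`, and `ASSUMED_REFRESH_THRESHOLD` are hypothetical names; the threshold value is an assumption, since the value of `_helpers.REFRESH_THRESHOLD` is not shown in this diff. The blocking branch also assumes that `expired` compares the current time against the skewed expiry, which the collapsed part of the hunk does not show.

```python
import datetime
from enum import Enum

# Assumed value for illustration only; the diff uses _helpers.REFRESH_THRESHOLD.
ASSUMED_REFRESH_THRESHOLD = datetime.timedelta(minutes=5)


class SketchTokenState(Enum):
    """Hypothetical stand-in for TokenState."""

    FRESH = 1
    STALE = 2
    INVALID = 3


def classify(token, expiry, now, threshold=ASSUMED_REFRESH_THRESHOLD):
    """Classifies a token using the same branch order as token_state."""
    if token is None:
        return SketchTokenState.INVALID
    if expiry is None:
        # Credentials that cannot expire are always treated as fresh.
        return SketchTokenState.FRESH
    if now >= expiry:
        return SketchTokenState.INVALID
    if now >= expiry - threshold:
        return SketchTokenState.STALE
    return SketchTokenState.FRESH


def choose_action(state, non_blocking):
    """Returns 'none', 'background', or 'foreground' for a given state."""
    if state is SketchTokenState.FRESH:
        return "none"
    if state is SketchTokenState.STALE and non_blocking:
        # Stale tokens remain usable, so the refresh can happen off-thread.
        return "background"
    # Invalid tokens, and stale tokens in blocking mode, refresh inline.
    return "foreground"
```

The only behavioural difference between the two modes in this sketch is the STALE case: blocking mode refreshes inline, while non-blocking mode keeps using the current token and refreshes in the background.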
5 changes: 4 additions & 1 deletion google/auth/impersonated_credentials.py
@@ -259,7 +259,10 @@ def _update_token(self, request):
"""

# Refresh our source credentials if it is not valid.
-        if not self._source_credentials.valid:
+        if (
+            self._source_credentials.token_state == credentials.TokenState.STALE
+            or self._source_credentials.token_state == credentials.TokenState.INVALID
+        ):
self._source_credentials.refresh(request)

body = {
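
The new condition refreshes the source credentials whenever they are STALE or INVALID. The same check can be written as a membership test, sketched below with a hypothetical `SketchTokenState` enum and `source_needs_refresh` helper standing in for `credentials.TokenState` and the inline condition.

```python
from enum import Enum


class SketchTokenState(Enum):
    """Hypothetical stand-in for credentials.TokenState."""

    FRESH = 1
    STALE = 2
    INVALID = 3


def source_needs_refresh(state):
    """True when the source credentials should be refreshed before use."""
    return state in (SketchTokenState.STALE, SketchTokenState.INVALID)
```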
4 changes: 4 additions & 0 deletions google/oauth2/credentials.py
@@ -161,6 +161,8 @@ def __getstate__(self):
# because they need to be importable.
# Instead, the refresh_handler setter should be used to repopulate this.
del state_dict["_refresh_handler"]
# Remove worker as it contains multiprocessing queue objects.
del state_dict["_refresh_worker"]
return state_dict

def __setstate__(self, d):
@@ -183,6 +185,8 @@ def __setstate__(self, d):
self._universe_domain = d.get("_universe_domain") or _DEFAULT_UNIVERSE_DOMAIN
# The refresh_handler setter should be used to repopulate this.
self._refresh_handler = None
self._refresh_worker = None
self._use_non_blocking_refresh = d.get("_use_non_blocking_refresh")

@property
def refresh_token(self):
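
The `__getstate__` and `__setstate__` changes follow a common pattern for objects that hold unpicklable helpers such as a `threading.Lock`: drop the helper when saving state and restore a placeholder when loading. A minimal sketch with a hypothetical `SketchCreds` class is below. Note two deliberate differences from the diff, both assumptions made for illustration: the helper is recreated on load rather than set to None, and the flag falls back to False when it is missing from older state.

```python
import threading


class SketchCreds:
    """Hypothetical object holding an unpicklable helper (a lock)."""

    def __init__(self, token):
        self.token = token
        self._use_non_blocking_refresh = False
        self._lock = threading.Lock()

    def __getstate__(self):
        state = self.__dict__.copy()
        # Lock objects cannot be pickled, so drop the helper from the state.
        del state["_lock"]
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        # Assumption: recreate the helper on load (the diff sets it to None).
        self._lock = threading.Lock()
        # Assumption: default to False for state saved before the flag existed.
        self._use_non_blocking_refresh = state.get(
            "_use_non_blocking_refresh", False
        )
```

Both hooks are invoked automatically by `pickle.dumps` and `pickle.loads`, so a round trip through pickle yields an object with a fresh lock and the saved token.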
Binary file modified system_tests/secrets.tar.enc
Binary file not shown.
10 changes: 8 additions & 2 deletions tests/oauth2/test_credentials.py
@@ -24,6 +24,7 @@
from google.auth import _helpers
from google.auth import exceptions
from google.auth import transport
from google.auth.credentials import TokenState
from google.oauth2 import credentials


@@ -61,6 +62,7 @@ def test_default_state(self):
assert not credentials.expired
# Scopes aren't required for these credentials
assert not credentials.requires_scopes
assert credentials.token_state == TokenState.INVALID
# Test properties
assert credentials.refresh_token == self.REFRESH_TOKEN
assert credentials.token_uri == self.TOKEN_URI
@@ -911,7 +913,11 @@ def test_pickle_and_unpickle(self):
assert list(creds.__dict__).sort() == list(unpickled.__dict__).sort()

for attr in list(creds.__dict__):
-            assert getattr(creds, attr) == getattr(unpickled, attr)
+            # Worker should always be None
+            if attr == "_refresh_worker":
+                assert getattr(unpickled, attr) is None
+            else:
+                assert getattr(creds, attr) == getattr(unpickled, attr)

def test_pickle_and_unpickle_universe_domain(self):
# old version of auth lib doesn't have _universe_domain, so the pickled
@@ -945,7 +951,7 @@ def test_pickle_and_unpickle_with_refresh_handler(self):
for attr in list(creds.__dict__):
# For the _refresh_handler property, the unpickled creds should be
# set to None.
-            if attr == "_refresh_handler":
+            if attr == "_refresh_handler" or attr == "_refresh_worker":
assert getattr(unpickled, attr) is None
else:
assert getattr(creds, attr) == getattr(unpickled, attr)