feat: Add optional non blocking refresh for sync auth code #1368

Merged
merged 34 commits into from
Dec 18, 2023
Changes from 6 commits
Commits
a07551f
feat: Add optional non blocking refresh for sync auth code
clundin25 Aug 11, 2023
20f8ec4
Rename refresh_window to is_stale.
clundin25 Oct 31, 2023
377e19c
Add API to get background refresh errors.
clundin25 Nov 2, 2023
31e2f33
Apply error handling feedback.
clundin25 Nov 9, 2023
aab77bb
chore: Refresh system test creds.
clundin25 Nov 9, 2023
1848c2a
Fix linter.
clundin25 Nov 9, 2023
8d741d5
No cover issue.
clundin25 Nov 9, 2023
20941ed
Fix coverage test.
clundin25 Nov 10, 2023
3ca4b97
Fix linter.
clundin25 Nov 10, 2023
c7b11f8
chore: Refresh system test creds.
clundin25 Nov 10, 2023
0417312
Fixed linter error.
clundin25 Nov 10, 2023
64e1c7c
Merge branch 'main' into pre-emptive-refresh
BigTailWolf Nov 13, 2023
b94d6e1
Fix coverage test.
clundin25 Nov 13, 2023
88c00e1
Add no cover.
clundin25 Nov 13, 2023
e2623f4
Merge remote-tracking branch 'origin/pre-emptive-refresh' into pre-em…
clundin25 Nov 13, 2023
b5281bd
chore: Refresh system test creds.
clundin25 Nov 13, 2023
5d30915
One more cover pragma and run linter.
clundin25 Nov 13, 2023
4b146c4
Remove unnecessary code change.
clundin25 Nov 15, 2023
5476a8f
Merge remote-tracking branch 'upstream/main' into pre-emptive-refresh
clundin25 Nov 15, 2023
77ff534
chore: Refresh system test creds.
clundin25 Nov 15, 2023
a19fe0d
PR feedback.
clundin25 Dec 14, 2023
0e35c67
Merge remote-tracking branch 'upstream/main' into pre-emptive-refresh
clundin25 Dec 14, 2023
f0b218e
chore: Refresh system test creds.
clundin25 Dec 14, 2023
a6b1815
Fix lint
clundin25 Dec 14, 2023
df7220c
Use field to store error in RefreshThread.
clundin25 Dec 14, 2023
89ceeeb
Remove some left over constants. Combine locks.
clundin25 Dec 14, 2023
41042c2
PR feedback.
clundin25 Dec 14, 2023
f557447
Remove exception deadlock
clundin25 Dec 14, 2023
4cecc4a
Don't try to refresh stale tokens that have an active error.
clundin25 Dec 14, 2023
be2e2da
Fix cover check and logic error on `has_error()`
clundin25 Dec 14, 2023
d941394
PR feedback.
clundin25 Dec 15, 2023
f5e77e5
Deep copy request object when starting background refresh.
clundin25 Dec 15, 2023
be5f27e
PR feedback.
clundin25 Dec 18, 2023
fcf1f8a
Merge remote-tracking branch 'upstream/main' into pre-emptive-refresh
clundin25 Dec 18, 2023
112 changes: 27 additions & 85 deletions google/auth/_refresh_worker.py
@@ -13,39 +13,29 @@
# limitations under the License.

import logging
import queue
import threading

import google.auth.exceptions as e

WORKER_TIMEOUT_SECONDS = 5

_LOGGER = logging.getLogger(__name__)


class RefreshWorker:
class RefreshThreadManager:
"""
A worker that will perform a non-blocking refresh of credentials.
Organizes exactly one background job that refreshes a token.
"""

MAX_REFRESH_QUEUE_SIZE = 1
MAX_ERROR_QUEUE_SIZE = 2

def __init__(self):
"""Initializes the worker."""
"""Initializes the manager."""

self._refresh_queue = queue.Queue(self.MAX_REFRESH_QUEUE_SIZE)
# Bound the error queue to avoid infinitely growing the heap.
self._error_queue = queue.Queue(self.MAX_ERROR_QUEUE_SIZE)
self._worker = None
self._lock = threading.Lock()

def _need_worker(self):
return self._worker is None or not self._worker.is_alive()

def _spawn_worker(self):
self._worker = RefreshThread(
work_queue=self._refresh_queue, error_queue=self._error_queue
)
def _spawn_worker(self, cred, request):
self._worker = RefreshThread(cred=cred, request=request)
self._worker.start()

def start_refresh(self, cred, request):


Is request immutable, or can we create a copy here?

Contributor Author

Can you expand on why copy it? I don't think it is necessary, as the request should not be re-used in other places.

I think copying the request would be a breaking change.

I'd have to do some work to understand the implication of deep copying a request object.
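
The concern behind the copy question can be illustrated with a small sketch. If the caller mutates the request after handing it to a background thread, both threads see the change unless the hand-off makes a copy. This is a minimal sketch, assuming a hypothetical `FakeRequest` class that stands in for a real transport request (it is not part of google-auth); only `copy.deepcopy` and `threading.Thread` from the standard library are used.

```python
import copy
import threading


class FakeRequest:
    """Hypothetical stand-in for a transport request object (assumption)."""

    def __init__(self):
        self.headers = {"x-attempt": "1"}


def start_background(request, results):
    # Copy before the hand-off so later caller-side mutation is not
    # visible to the background thread.
    request_copy = copy.deepcopy(request)
    return threading.Thread(
        target=lambda: results.append(dict(request_copy.headers))
    )


request = FakeRequest()
results = []
worker = start_background(request, results)
request.headers["x-attempt"] = "2"  # caller mutates after the hand-off
worker.start()
worker.join()
```

Whether a real request object can be deep copied safely depends on what it holds (sessions, sockets, locks), which is the open question raised in the reply above.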

@@ -62,101 +52,53 @@ def start_refresh(self, cred, request):
"Unable to start refresh. cred and request must be valid and instantiated objects."
)

# This test case is covered by the unit tests but sometimes the cover
# check can flake due to the scheduler.
#
# Specifically this test is covered by test_refresh_dead_worker
if not self._refresh_queue.empty(): # pragma: NO COVER
if self._need_worker():
self._spawn_worker()
return
with self._lock:
if self._worker is not None and self._worker._error_info is not None:
raise e.RefreshError(


This has a weird 'every 2nd call fails' pattern that doesn't seem quite right.

I would suggest we try to background refresh once, log and record the error, then do foreground refreshes from then on until we get a new token.

You might return false from start_refresh to cause the caller to call refresh on their own thread. The caller would be responsible for calling "clear error" any time they have a good token from refresh().

Contributor Author

Sounds good!
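
The flow suggested in this thread can be sketched roughly as follows. This is a hedged illustration, not the code that was merged: `BackgroundRefresher`, `clear_error`, and the two refresh functions are hypothetical names, and the "background" work runs inline here to keep the example deterministic.

```python
class BackgroundRefresher:
    """Hypothetical helper illustrating the suggested flow (assumption)."""

    def __init__(self):
        self._error = None

    def start_refresh(self, refresh_fn):
        # After a recorded failure, decline so the caller refreshes itself.
        if self._error is not None:
            return False
        try:
            refresh_fn()  # inline here; a real version would use a thread
        except Exception as err:
            self._error = err  # record the failure for later calls
        return True

    def clear_error(self):
        self._error = None


def failing():
    raise RuntimeError("token endpoint unavailable")


def succeeding():
    return "new-token"


refresher = BackgroundRefresher()
first = refresher.start_refresh(failing)   # started; the failure is recorded
second = refresher.start_refresh(failing)  # declined because of the error
if not second:
    token = succeeding()     # foreground refresh on the caller's thread
    refresher.clear_error()  # a good token was obtained, so reset the error
third = refresher.start_refresh(succeeding)
```

Returning a boolean instead of raising avoids the alternating failure pattern: the caller always has a clear signal to fall back to a blocking refresh.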

f"Could not start a background refresh. The background refresh previously failed with {self._worker._error_info}."
) from self._worker._error_info

try:
self._refresh_queue.put_nowait((cred, request))
except queue.Full:
return

# This test case is covered by the unit tests but sometimes the cover
# check can flake due to the scheduler.
if self._need_worker(): # pragma: NO COVER
self._spawn_worker()

def error_queue_full(self):
"""
True if the refresh worker error queue is full. False if it is not yet full.

Returns:
bool
"""
return self._error_queue.full()

def flush_error_queue(self):
"""
Drop all errors in the error queue.
"""
try:
while not self._error_queue.empty():
_ = self._error_queue.get_nowait()
# This condition is unlikely but there is a possibility that an
# error gets queued between the empty and get calls
except queue.Empty: # pragma: NO COVER
pass
if self._need_worker(): # pragma: NO COVER
self._spawn_worker(cred, request)

def get_error(self):
"""
Returns the first error in the error queue. It is recommended to flush the full error queue to root cause refresh failures.
Returns the error that occurred in the refresh thread. Clears the error once called.

Returns:
Optional[exceptions.Exception]
"""
try:
return self._error_queue.get_nowait()
except queue.Empty:
if not self._worker:
return None
err, self._worker._error_info = self._worker._error_info, None
return err


class RefreshThread(threading.Thread):
"""
Thread that refreshes credentials.
"""

def __init__(self, work_queue, error_queue, **kwargs):
def __init__(self, cred, request, **kwargs):
"""Initializes the thread.

Args:
work_queue: A queue of credentials and request tuples.
error_queue: A queue containing errors that prevented a credential refresh.
cred: A Credential object to refresh.
request: A Request object used to perform a credential refresh.
**kwargs: Additional keyword arguments.
"""

super().__init__(**kwargs)
self._work_queue = work_queue
self._error_queue = error_queue
self._cred = cred
self._request = request
self._error_info = None

def run(self):
"""
Gets credentials and request objects from a queue.

The thread will block until a work item appears from the queue.

Once the refresh has completed, the thread will mark the queue task
as complete, and exit.
Perform the credential refresh.
"""
try:
cred, request = self._work_queue.get(timeout=WORKER_TIMEOUT_SECONDS)
except queue.Empty:
_LOGGER.error(
f"Timed out waiting for refresh work after {WORKER_TIMEOUT_SECONDS} seconds. This could mean there is a race condition, work starvation, or other logic error in the refresh code."
)
return
try:
cred.refresh(request)
self._cred.refresh(self._request)
except Exception as err: # pragma: NO COVER
# This condition is covered by the unit test test_refresh_error but
# it can be flaky due to the scheduler.
_LOGGER.error(f"Background refresh failed due to: {err}")
if not self._error_queue.full():
self._error_queue.put_nowait(err)

# The coverage tool is not able to capture this line, but it is covered
# by test_start_refresh in the unit tests.
self._work_queue.task_done()
self._error_info = err
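
The shape of the queue-free design in this file can be condensed into a runnable sketch. It is an approximation under stated assumptions: `FakeCred`, `SketchRefreshThread`, and `SketchManager` are hypothetical names, and the sketch leaves out the argument validation and the "previous refresh failed" check shown in the diff.

```python
import threading


class FakeCred:
    """Hypothetical credential whose refresh always fails (assumption)."""

    def refresh(self, request):
        raise ValueError("refresh failed")


class SketchRefreshThread(threading.Thread):
    def __init__(self, cred, request, **kwargs):
        super().__init__(**kwargs)
        self._cred = cred
        self._request = request
        self._error_info = None

    def run(self):
        try:
            self._cred.refresh(self._request)
        except Exception as err:
            # Store the failure on the thread instead of pushing to a queue.
            self._error_info = err


class SketchManager:
    def __init__(self):
        self._worker = None
        self._lock = threading.Lock()

    def start_refresh(self, cred, request):
        with self._lock:
            # Only one background job may be alive at a time.
            if self._worker is None or not self._worker.is_alive():
                self._worker = SketchRefreshThread(cred=cred, request=request)
                self._worker.start()

    def get_error(self):
        if not self._worker:
            return None
        # Return the stored error and clear it in one step.
        err, self._worker._error_info = self._worker._error_info, None
        return err


manager = SketchManager()
manager.start_refresh(FakeCred(), request=None)
manager._worker.join()  # wait only so the example is deterministic
first_error = manager.get_error()
second_error = manager.get_error()
```

Keeping the error as a field on the thread removes the need for bounded queues and the timeout path that the old `run` method had to handle.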
67 changes: 44 additions & 23 deletions google/auth/credentials.py
@@ -22,7 +22,7 @@
from google.auth import _helpers, environment_vars
from google.auth import exceptions
from google.auth import metrics
from google.auth._refresh_worker import RefreshWorker
from google.auth._refresh_worker import RefreshThreadManager


class Credentials(metaclass=abc.ABCMeta):
@@ -62,7 +62,7 @@ def __init__(self):
"""

self._use_non_blocking_refresh = False
self._refresh_worker = RefreshWorker()
self._refresh_worker = RefreshThreadManager()

@property
def expired(self):
@@ -71,30 +71,45 @@ def expired(self):
Note that credentials can be invalid but not expired because
credentials with :attr:`expiry` set to None are considered to never
expire.

.. deprecated:: v2.24.0
Prefer checking :attr:`token_state` instead.
"""
if not self.expiry:
return False

return _helpers.utcnow() >= self.expiry
# Remove some threshold from expiry to err on the side of reporting
# expiration early so that we avoid the 401-refresh-retry loop.
skewed_expiry = self.expiry - _helpers.REFRESH_THRESHOLD
return _helpers.utcnow() >= skewed_expiry

@property
def valid(self):
"""Checks the validity of the credentials.

This is True if the credentials have a :attr:`token` and the token
is not :attr:`expired`.

.. deprecated:: v2.24.0
Prefer checking :attr:`token_state` instead.
"""
return self.token is not None and not self.expired

@property
def token_state(self):
if not self.valid or self.expired:
"""
See :obj:`TokenState`
"""
if self.token is None:
return TokenState.INVALID

# Credentials that can't expire are always treated as fresh.
if not self.expiry:
if self.expiry is None:
return TokenState.FRESH

expired = _helpers.utcnow() >= self.expiry
if expired:
return TokenState.INVALID

is_stale = _helpers.utcnow() >= (self.expiry - _helpers.REFRESH_THRESHOLD)
if is_stale:
return TokenState.STALE
@@ -171,6 +186,17 @@ def apply(self, headers, token=None):
if self.quota_project_id:
headers["x-goog-user-project"] = self.quota_project_id

def _blocking_refresh(self, request):
if not self.valid:
self.refresh(request)

def _non_blocking_refresh(self, request):
if self.token_state == TokenState.STALE:
self._refresh_worker.start_refresh(self, request)

if self.token_state == TokenState.INVALID:
self.refresh(request)

def before_request(self, request, method, url, headers):
"""Performs credential-specific before request logic.

@@ -188,21 +214,10 @@ def before_request(self, request, method, url, headers):
# pylint: disable=unused-argument
# (Subclasses may use these arguments to ascertain information about
# the http request.)

if self.token_state == TokenState.FRESH:
self._refresh_worker.flush_error_queue()

if self.token_state == TokenState.STALE:
if (
self._use_non_blocking_refresh
and not self._refresh_worker.error_queue_full()
):
self._refresh_worker.start_refresh(self, request)
else:
self.refresh(request)

if self.token_state == TokenState.INVALID:
self.refresh(request)
if self._use_non_blocking_refresh:
self._non_blocking_refresh(request)
else:
self._blocking_refresh(request)

metrics.add_metric_header(headers, self._metric_header_for_usage())
self.apply(headers)
@@ -212,9 +227,8 @@ def with_non_blocking_refresh(self):

def get_background_refresh_error(self):


I think I would prefer to remove this from the public API:

  • it is not relevant without the feature flag
  • it isn't clear when you'd want to call it
  • it's hard to implement safely

"""
Returns the first error in the background error queue. It is recommended to flush the full error queue to root cause refresh failures.
Returns the error from a failed background refresh. Once called, the error will be flushed.

This error queue is populated by the token refreshes performed in a background thread.
Returns:
Optional[exceptions.Exception]
"""
@@ -469,6 +483,13 @@ def signer(self):


class TokenState(Enum):
"""
Tracks the state of a token.
FRESH: The token is not expired and can be used normally.
STALE: The token is close to expired, and should be refreshed. The token can be used normally.
INVALID: The token is expired or invalid. The token cannot be used for a normal operation.
"""

FRESH = 1
STALE = 2
INVALID = 3
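
The three states and the order of checks in `token_state` can be tried out in isolation. The sketch below mirrors that order with a standalone function; `classify`, `State`, and the five-minute `REFRESH_THRESHOLD` are assumptions chosen for illustration, and the actual threshold in `google.auth._helpers` may differ.

```python
import datetime
from enum import Enum

# Assumed threshold for illustration only; the real value lives in
# google.auth._helpers.REFRESH_THRESHOLD and may differ.
REFRESH_THRESHOLD = datetime.timedelta(minutes=5)


class State(Enum):
    FRESH = 1
    STALE = 2
    INVALID = 3


def classify(token, expiry, now):
    if token is None:
        return State.INVALID
    # Tokens without an expiry are always treated as fresh.
    if expiry is None:
        return State.FRESH
    if now >= expiry:
        return State.INVALID
    # Inside the threshold window: still usable, but due for a refresh.
    if now >= expiry - REFRESH_THRESHOLD:
        return State.STALE
    return State.FRESH


now = datetime.datetime(2023, 12, 18, 12, 0, 0)
fresh = classify("t", now + datetime.timedelta(hours=1), now)
stale = classify("t", now + datetime.timedelta(minutes=2), now)
expired = classify("t", now - datetime.timedelta(seconds=1), now)
missing = classify(None, None, now)
no_expiry = classify("t", None, now)
```

Checking for expiry before staleness matters: an expired token also satisfies the stale comparison, so the stricter test has to come first.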
3 changes: 1 addition & 2 deletions google/auth/impersonated_credentials.py
@@ -260,8 +260,7 @@ def _update_token(self, request):

# Refresh our source credentials if they are not valid.
if (
not self._source_credentials.valid
or self._source_credentials.token_state == credentials.TokenState.STALE
self._source_credentials.token_state == credentials.TokenState.STALE
or self._source_credentials.token_state == credentials.TokenState.INVALID
):
self._source_credentials.refresh(request)
Binary file modified system_tests/secrets.tar.enc
Binary file not shown.