Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add optional non blocking refresh for sync auth code #1368

Merged
merged 34 commits into from
Dec 18, 2023
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
a07551f
feat: Add optional non blocking refresh for sync auth code
clundin25 Aug 11, 2023
20f8ec4
Rename refresh_window to is_stale.
clundin25 Oct 31, 2023
377e19c
Add API to get background refresh errors.
clundin25 Nov 2, 2023
31e2f33
Apply error handling feedback.
clundin25 Nov 9, 2023
aab77bb
chore: Refresh system test creds.
clundin25 Nov 9, 2023
1848c2a
Fix linter.
clundin25 Nov 9, 2023
8d741d5
No cover issue.
clundin25 Nov 9, 2023
20941ed
Fix coverage test.
clundin25 Nov 10, 2023
3ca4b97
Fix linter.
clundin25 Nov 10, 2023
c7b11f8
chore: Refresh system test creds.
clundin25 Nov 10, 2023
0417312
Fixed linter error.
clundin25 Nov 10, 2023
64e1c7c
Merge branch 'main' into pre-emptive-refresh
BigTailWolf Nov 13, 2023
b94d6e1
Fix coverage test.
clundin25 Nov 13, 2023
88c00e1
Add no cover.
clundin25 Nov 13, 2023
e2623f4
Merge remote-tracking branch 'origin/pre-emptive-refresh' into pre-em…
clundin25 Nov 13, 2023
b5281bd
chore: Refresh system test creds.
clundin25 Nov 13, 2023
5d30915
One more cover pragma and run linter.
clundin25 Nov 13, 2023
4b146c4
Remove unnecessary code change.
clundin25 Nov 15, 2023
5476a8f
Merge remote-tracking branch 'upstream/main' into pre-emptive-refresh
clundin25 Nov 15, 2023
77ff534
chore: Refresh system test creds.
clundin25 Nov 15, 2023
a19fe0d
PR feedback.
clundin25 Dec 14, 2023
0e35c67
Merge remote-tracking branch 'upstream/main' into pre-emptive-refresh
clundin25 Dec 14, 2023
f0b218e
chore: Refresh system test creds.
clundin25 Dec 14, 2023
a6b1815
Fix lint
clundin25 Dec 14, 2023
df7220c
Use field to store error in RefreshThread.
clundin25 Dec 14, 2023
89ceeeb
Remove some left over constants. Combine locks.
clundin25 Dec 14, 2023
41042c2
PR feedback.
clundin25 Dec 14, 2023
f557447
Remove exception deadlock
clundin25 Dec 14, 2023
4cecc4a
Don't try to refresh stale tokens that have an active error.
clundin25 Dec 14, 2023
be2e2da
Fix cover check and logic error on `has_error()`
clundin25 Dec 14, 2023
d941394
PR feedback.
clundin25 Dec 15, 2023
f5e77e5
Deep copy request object when starting background refresh.
clundin25 Dec 15, 2023
be5f27e
PR feedback.
clundin25 Dec 18, 2023
fcf1f8a
Merge remote-tracking branch 'upstream/main' into pre-emptive-refresh
clundin25 Dec 18, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
162 changes: 162 additions & 0 deletions google/auth/_refresh_worker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import queue
import threading

import google.auth.exceptions as e

WORKER_TIMEOUT_SECONDS = 5

_LOGGER = logging.getLogger(__name__)


class RefreshWorker:
clundin25 marked this conversation as resolved.
Show resolved Hide resolved
"""
A worker that will perform a non-blocking refresh of credentials.
"""

MAX_REFRESH_QUEUE_SIZE = 1
MAX_ERROR_QUEUE_SIZE = 2

def __init__(self):
"""Initializes the worker."""

self._refresh_queue = queue.Queue(self.MAX_REFRESH_QUEUE_SIZE)
# Bound the error queue to avoid infinitely growing the heap.
self._error_queue = queue.Queue(self.MAX_ERROR_QUEUE_SIZE)
self._worker = None

def _need_worker(self):
return self._worker is None or not self._worker.is_alive()

def _spawn_worker(self):
self._worker = RefreshThread(
work_queue=self._refresh_queue, error_queue=self._error_queue
)
self._worker.start()

def start_refresh(self, cred, request):

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is request immutable or can we create a copy here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you expand on why copy it? I don't think it is necessary, as the request should not be re-used in other places.

I think copying the request would be a breaking change.

I'd have to do some work to understand the implication of deep copying a request object.

"""Starts a refresh thread for the given credentials.
The credentials are refreshed using the request parameter.
request and cred MUST not be None

Args:
cred: A credentials object.
request: A request object.
"""
if cred is None or request is None:
raise e.InvalidValue(
"Unable to start refresh. cred and request must be valid and instantiated objects."
)

clundin25 marked this conversation as resolved.
Show resolved Hide resolved
# This test case is covered by the unit tests but sometimes the cover
# check can flake due to the schdule.
#
# Specifially this test is covered by test_refresh_dead_worker
if not self._refresh_queue.empty(): # pragma: NO COVER
if self._need_worker():
self._spawn_worker()
return

try:
self._refresh_queue.put_nowait((cred, request))
except queue.Full:
return

# This test case is covered by the unit tests but sometimes the cover
# check can flake due to the schdule.
if self._need_worker(): # pragma: NO COVER
self._spawn_worker()

def error_queue_full(self):
"""
True if the refresh worker error queue is full. False if it is not yet full.

Returns:
bool
"""
return self._error_queue.full()

def flush_error_queue(self):
"""
Drop all errors in the error queue.
"""
try:
while not self._error_queue.empty():
_ = self._error_queue.get_nowait()
# This condition is unlikely but there is a possibility that an
# error gets queued between the empty and get calls
except queue.Empty: # pragma: NO COVER
pass

def get_error(self):
"""
Returns the first error in the error queue. It is recommended to flush the full error queue to root cause refresh failures.
Returns:
Optional[exceptions.Exception]
"""
try:
return self._error_queue.get_nowait()
except queue.Empty:
return None


class RefreshThread(threading.Thread):
"""
Thread that refreshes credentials.
"""

def __init__(self, work_queue, error_queue, **kwargs):
"""Initializes the thread.

Args:
work_queue: A queue of credentials and request tuples.
error_queue: A queue containing errors that prevented a credential refresh.
**kwargs: Additional keyword arguments.
"""

super().__init__(**kwargs)
self._work_queue = work_queue
self._error_queue = error_queue

def run(self):
"""
Gets credentials and request objects from a queue.

The thread will block until a work item appears from the queue.

Once the refresh has completed, the thread will mark the queue task
as complete, and exit.
"""
try:
cred, request = self._work_queue.get(timeout=WORKER_TIMEOUT_SECONDS)
except queue.Empty:
_LOGGER.error(
f"Timed out waiting for refresh work after {WORKER_TIMEOUT_SECONDS} seconds. This could mean there is a race condition, work starvation, or other logic error in the refresh code."
)
return
try:
cred.refresh(request)
except Exception as err: # pragma: NO COVER
# This condition is covered by the unit test test_refresh_error but
# it can be flaky due to the scheduler.
_LOGGER.error(f"Background refresh failed due to: {err}")
if not self._error_queue.full():
self._error_queue.put_nowait(err)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

with the code above, you can just stuff the error into a field.


# The coverage tool is not able to capturre this line, but it is covered
# by test_start_refresh in the unit tests.
self._work_queue.task_done()
60 changes: 55 additions & 5 deletions google/auth/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@
"""Interfaces for credentials."""

import abc
from enum import Enum
import os

from google.auth import _helpers, environment_vars
from google.auth import exceptions
from google.auth import metrics
from google.auth._refresh_worker import RefreshWorker


class Credentials(metaclass=abc.ABCMeta):
Expand Down Expand Up @@ -59,6 +61,9 @@ def __init__(self):
"""Optional[str]: The universe domain value, default is googleapis.com
"""

self._use_non_blocking_refresh = False
self._refresh_worker = RefreshWorker()

@property
def expired(self):
"""Checks if the credentials are expired.
Expand All @@ -70,10 +75,7 @@ def expired(self):
if not self.expiry:
return False

# Remove some threshold from expiry to err on the side of reporting
# expiration early so that we avoid the 401-refresh-retry loop.
skewed_expiry = self.expiry - _helpers.REFRESH_THRESHOLD
return _helpers.utcnow() >= skewed_expiry
return _helpers.utcnow() >= self.expiry
clundin25 marked this conversation as resolved.
Show resolved Hide resolved

@property
def valid(self):
Expand All @@ -84,6 +86,21 @@ def valid(self):
"""
return self.token is not None and not self.expired

@property
def token_state(self):
clundin25 marked this conversation as resolved.
Show resolved Hide resolved
if not self.valid or self.expired:
return TokenState.INVALID

# Credentials that can't expire are always treated as fresh.
if not self.expiry:
return TokenState.FRESH

is_stale = _helpers.utcnow() >= (self.expiry - _helpers.REFRESH_THRESHOLD)
if is_stale:
return TokenState.STALE

return TokenState.FRESH

@property
def quota_project_id(self):
"""Project to use for quota and billing purposes."""
Expand Down Expand Up @@ -171,11 +188,38 @@ def before_request(self, request, method, url, headers):
# pylint: disable=unused-argument
# (Subclasses may use these arguments to ascertain information about
# the http request.)
if not self.valid:

if self.token_state == TokenState.FRESH:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This section doesn't guard access to refresh worker resources with your feature flag.

Maybe something like:

if _use_non_blocking_refresh():
  _non_blocking_refresh() # thread maintenance, stale handling, refresh on invalid
else
  _blocking_refresh() # refresh on stale or invalid

self._refresh_worker.flush_error_queue()

if self.token_state == TokenState.STALE:
if (
self._use_non_blocking_refresh
and not self._refresh_worker.error_queue_full()
):
self._refresh_worker.start_refresh(self, request)
else:
self.refresh(request)

if self.token_state == TokenState.INVALID:
self.refresh(request)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Multiple requests may end up doing this refresh right?
Can we have a lock on this sync refresh?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The code is using a bounded queue which should reduce duplicate work. There is a risk that there is a small amount of duplicate work, if two threads queue work at the same time. The queue itself is thread safe + has locks.

I believe using the queue to reduce duplicate work is acceptable, and that multiple threads will perform the refresh occasionally.

There should be no data corruption due to this


metrics.add_metric_header(headers, self._metric_header_for_usage())
self.apply(headers)

def with_non_blocking_refresh(self):
self._use_non_blocking_refresh = True

def get_background_refresh_error(self):

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I would prefer to remove this from the public API:

  • it is not relevant without the feature flag
  • it isn't clear when you'd want to call it
  • it's hard to implement safely

"""
Returns the first error in the background error queue. It is recommended to flush the full error queue to root cause refresh failures.

This error queue is populated by the token refreshes performed in a background thread.
Returns:
Optional[exceptions.Exception]
"""
return self._refresh_worker.get_error()


class CredentialsWithQuotaProject(Credentials):
"""Abstract base for credentials supporting ``with_quota_project`` factory"""
Expand Down Expand Up @@ -422,3 +466,9 @@ def signer(self):
# pylint: disable=missing-raises-doc
# (pylint doesn't recognize that this is abstract)
raise NotImplementedError("Signer must be implemented.")


class TokenState(Enum):
clundin25 marked this conversation as resolved.
Show resolved Hide resolved
FRESH = 1
STALE = 2
INVALID = 3
6 changes: 5 additions & 1 deletion google/auth/impersonated_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,11 @@ def _update_token(self, request):
"""

# Refresh our source credentials if it is not valid.
if not self._source_credentials.valid:
if (
not self._source_credentials.valid

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we remove this dependency and juse use TokenState.INVALID?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, I've made this change

or self._source_credentials.token_state == credentials.TokenState.STALE
or self._source_credentials.token_state == credentials.TokenState.INVALID
):
self._source_credentials.refresh(request)

body = {
Expand Down
4 changes: 4 additions & 0 deletions google/oauth2/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,8 @@ def __getstate__(self):
# because they need to be importable.
# Instead, the refresh_handler setter should be used to repopulate this.
del state_dict["_refresh_handler"]
# Remove worker as it contains multiproccessing queue objects.
del state_dict["_refresh_worker"]
return state_dict

def __setstate__(self, d):
Expand All @@ -177,6 +179,8 @@ def __setstate__(self, d):
self._universe_domain = d.get("_universe_domain")
# The refresh_handler setter should be used to repopulate this.
self._refresh_handler = None
self._refresh_worker = None
self._use_non_blocking_refresh = d.get("_use_non_blocking_refresh")

@property
def refresh_token(self):
Expand Down
Binary file modified system_tests/secrets.tar.enc
Binary file not shown.
10 changes: 8 additions & 2 deletions tests/oauth2/test_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from google.auth import _helpers
from google.auth import exceptions
from google.auth import transport
from google.auth.credentials import TokenState
from google.oauth2 import credentials


Expand Down Expand Up @@ -61,6 +62,7 @@ def test_default_state(self):
assert not credentials.expired
# Scopes aren't required for these credentials
assert not credentials.requires_scopes
assert credentials.token_state == TokenState.INVALID
# Test properties
assert credentials.refresh_token == self.REFRESH_TOKEN
assert credentials.token_uri == self.TOKEN_URI
Expand Down Expand Up @@ -893,7 +895,11 @@ def test_pickle_and_unpickle(self):
assert list(creds.__dict__).sort() == list(unpickled.__dict__).sort()

for attr in list(creds.__dict__):
assert getattr(creds, attr) == getattr(unpickled, attr)
# Worker should always be None
if attr == "_refresh_worker":
assert getattr(unpickled, attr) is None
else:
assert getattr(creds, attr) == getattr(unpickled, attr)

def test_pickle_and_unpickle_with_refresh_handler(self):
expected_expiry = _helpers.utcnow() + datetime.timedelta(seconds=2800)
Expand All @@ -916,7 +922,7 @@ def test_pickle_and_unpickle_with_refresh_handler(self):
for attr in list(creds.__dict__):
# For the _refresh_handler property, the unpickled creds should be
# set to None.
if attr == "_refresh_handler":
if attr == "_refresh_handler" or attr == "_refresh_worker":
assert getattr(unpickled, attr) is None
else:
assert getattr(creds, attr) == getattr(unpickled, attr)
Expand Down