Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add optional non blocking refresh for sync auth code #1368

Merged
merged 34 commits into from
Dec 18, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
a07551f
feat: Add optional non blocking refresh for sync auth code
clundin25 Aug 11, 2023
20f8ec4
Rename refresh_window to is_stale.
clundin25 Oct 31, 2023
377e19c
Add API to get background refresh errors.
clundin25 Nov 2, 2023
31e2f33
Apply error handling feedback.
clundin25 Nov 9, 2023
aab77bb
chore: Refresh system test creds.
clundin25 Nov 9, 2023
1848c2a
Fix linter.
clundin25 Nov 9, 2023
8d741d5
No cover issue.
clundin25 Nov 9, 2023
20941ed
Fix coverage test.
clundin25 Nov 10, 2023
3ca4b97
Fix linter.
clundin25 Nov 10, 2023
c7b11f8
chore: Refresh system test creds.
clundin25 Nov 10, 2023
0417312
Fixed linter error.
clundin25 Nov 10, 2023
64e1c7c
Merge branch 'main' into pre-emptive-refresh
BigTailWolf Nov 13, 2023
b94d6e1
Fix coverage test.
clundin25 Nov 13, 2023
88c00e1
Add no cover.
clundin25 Nov 13, 2023
e2623f4
Merge remote-tracking branch 'origin/pre-emptive-refresh' into pre-em…
clundin25 Nov 13, 2023
b5281bd
chore: Refresh system test creds.
clundin25 Nov 13, 2023
5d30915
One more cover pragma and run linter.
clundin25 Nov 13, 2023
4b146c4
Remove unnecessary code change.
clundin25 Nov 15, 2023
5476a8f
Merge remote-tracking branch 'upstream/main' into pre-emptive-refresh
clundin25 Nov 15, 2023
77ff534
chore: Refresh system test creds.
clundin25 Nov 15, 2023
a19fe0d
PR feedback.
clundin25 Dec 14, 2023
0e35c67
Merge remote-tracking branch 'upstream/main' into pre-emptive-refresh
clundin25 Dec 14, 2023
f0b218e
chore: Refresh system test creds.
clundin25 Dec 14, 2023
a6b1815
Fix lint
clundin25 Dec 14, 2023
df7220c
Use field to store error in RefreshThread.
clundin25 Dec 14, 2023
89ceeeb
Remove some left over constants. Combine locks.
clundin25 Dec 14, 2023
41042c2
PR feedback.
clundin25 Dec 14, 2023
f557447
Remove exception deadlock
clundin25 Dec 14, 2023
4cecc4a
Don't try to refresh stale tokens that have an active error.
clundin25 Dec 14, 2023
be2e2da
Fix cover check and logic error on `has_error()`
clundin25 Dec 14, 2023
d941394
PR feedback.
clundin25 Dec 15, 2023
f5e77e5
Deep copy request object when starting background refresh.
clundin25 Dec 15, 2023
be5f27e
PR feedback.
clundin25 Dec 18, 2023
fcf1f8a
Merge remote-tracking branch 'upstream/main' into pre-emptive-refresh
clundin25 Dec 18, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
42 changes: 13 additions & 29 deletions google/auth/_refresh_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import logging
import threading

Expand All @@ -36,9 +37,13 @@ def start_refresh(self, cred, request):
The credentials are refreshed using the request parameter.
request and cred MUST not be None

Returns True if a background refresh was kicked off. False otherwise.

Args:
cred: A credentials object.
request: A request object.
Returns:
bool
"""
if cred is None or request is None:
raise e.InvalidValue(
Expand All @@ -47,40 +52,19 @@ def start_refresh(self, cred, request):

clundin25 marked this conversation as resolved.
Show resolved Hide resolved
with self._lock:
if self._worker is not None and self._worker._error_info is not None:
# Reset error field to prevent deadlock for clients that try to
# rety this error.
err, self._worker._error_info = self._worker._error_info, None

raise e.RefreshError(
f"Could not start a background refresh. The background refresh previously failed with {self._worker._error_info}."
) from err
return False

if self._worker is None or not self._worker.is_alive(): # pragma: NO COVER
self._worker = RefreshThread(cred=cred, request=request)
self._worker = RefreshThread(cred=cred, request=copy.deepcopy(request))
self._worker.start()
return True

def has_error(self):
"""
Returns True if a refresh thread has had an exception, and the exception has not been cleared.

Returns:
Optional[Boolean]
"""
if not self._worker:
return False
return self._worker._error_info is not None

def get_error(self):
"""
Returns the error that occurred in the refresh thread. Clears the error once called.

Returns:
Optional[exceptions.Exception]
def clear_error(self):
"""
if not self._worker:
return None
err, self._worker._error_info = self._worker._error_info, None
return err
Removes any errors that were stored from previous background refreshes.
"""
if self._worker:
clundin25 marked this conversation as resolved.
Show resolved Hide resolved
self._worker._error_info = None


class RefreshThread(threading.Thread):
Expand Down
25 changes: 9 additions & 16 deletions google/auth/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,14 +191,16 @@ def _blocking_refresh(self, request):
self.refresh(request)

def _non_blocking_refresh(self, request):
if (
self.token_state == TokenState.STALE
and not self._refresh_worker.has_error()
):
self._refresh_worker.start_refresh(self, request)
use_blocking_refresh_fallback = False

if self.token_state == TokenState.INVALID:
if self.token_state == TokenState.STALE:
use_blocking_refresh_fallback = not self._refresh_worker.start_refresh(
self, request
)

if self.token_state == TokenState.INVALID or use_blocking_refresh_fallback:
self.refresh(request)
self._refresh_worker.clear_error()
clundin25 marked this conversation as resolved.
Show resolved Hide resolved

def before_request(self, request, method, url, headers):
"""Performs credential-specific before request logic.
Expand Down Expand Up @@ -228,15 +230,6 @@ def before_request(self, request, method, url, headers):
def with_non_blocking_refresh(self):
self._use_non_blocking_refresh = True

def get_background_refresh_error(self):
"""
Returns the error in from a failed background refresh. Once called, the error will be flushed.

Returns:
Optional[exceptions.Exception]
"""
return self._refresh_worker.get_error()


class CredentialsWithQuotaProject(Credentials):
"""Abstract base for credentials supporting ``with_quota_project`` factory"""
Expand Down Expand Up @@ -488,7 +481,7 @@ def signer(self):
class TokenState(Enum):
clundin25 marked this conversation as resolved.
Show resolved Hide resolved
"""
Tracks the state of a token.
FRESH: The token is valid. It is not expired or close to expired, or the token has no expiry. To make it mutually exclusive to STALE.
FRESH: The token is valid. It is not expired or close to expired, or the token has no expiry.
STALE: The token is close to expired, and should be refreshed. The token can be used normally.
INVALID: The token is expired or invalid. The token cannot be used for a normal operation.
"""
Expand Down
Binary file modified system_tests/secrets.tar.enc
Binary file not shown.
45 changes: 8 additions & 37 deletions tests/test__refresh_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def test_start_refresh():
w = _refresh_worker.RefreshThreadManager()
cred = MockCredentialsImpl()
request = mock.MagicMock()
w.start_refresh(cred, request)
assert w.start_refresh(cred, request)

assert w._worker is not None

Expand All @@ -71,7 +71,7 @@ def test_nonblocking_start_refresh():
w = _refresh_worker.RefreshThreadManager()
cred = MockCredentialsImpl(sleep_seconds=1)
request = mock.MagicMock()
w.start_refresh(cred, request)
assert w.start_refresh(cred, request)

assert w._worker is not None
assert not cred.token
Expand All @@ -85,7 +85,7 @@ def test_multiple_refreshes_multiple_workers(test_thread_count):

def _thread_refresh():
time.sleep(random.randrange(0, 5))
w.start_refresh(cred, request)
assert w.start_refresh(cred, request)

threads = [
threading.Thread(target=_thread_refresh) for _ in range(test_thread_count)
Expand All @@ -108,17 +108,13 @@ def test_refresh_error():

cred.refresh.side_effect = exceptions.RefreshError("Failed to refresh")

w.start_refresh(cred, request)
assert w.start_refresh(cred, request)

err = None
while err is None:
err = w.get_error()
while w._worker._error_info is None: # pragma: NO COVER
time.sleep(MAIN_THREAD_SLEEP_MS)

assert w._worker is not None
assert isinstance(err, exceptions.RefreshError)
assert w._worker._error_info is None
assert w.get_error() is None
assert isinstance(w._worker._error_info, exceptions.RefreshError)


def test_refresh_error_call_refresh_again():
Expand All @@ -128,14 +124,12 @@ def test_refresh_error_call_refresh_again():

cred.refresh.side_effect = exceptions.RefreshError("Failed to refresh")

w.start_refresh(cred, request)
assert w.start_refresh(cred, request)

while w._worker._error_info is None: # pragma: NO COVER
time.sleep(MAIN_THREAD_SLEEP_MS)

with pytest.raises(exceptions.RefreshError) as excinfo:
w.start_refresh(cred, request)
excinfo.match("Could not start a background refresh.*")
assert not w.start_refresh(cred, request)


def test_refresh_dead_worker():
Expand All @@ -151,26 +145,3 @@ def test_refresh_dead_worker():

assert cred.token == request
assert cred.refresh_count == 1


def test_worker_has_error():
w = _refresh_worker.RefreshThreadManager()
w._worker = mock.MagicMock()
w._worker._error_info = "Something"

assert w.has_error()


def test_no_worker_has_error():
w = _refresh_worker.RefreshThreadManager()
w._worker = None

assert not w.has_error()


def test_worker_has_error_no_error():
w = _refresh_worker.RefreshThreadManager()
w._worker = mock.MagicMock()
w._worker._error_info = None

assert not w.has_error()
51 changes: 32 additions & 19 deletions tests/test_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

from google.auth import _helpers
from google.auth import credentials
from google.auth import exceptions


class CredentialsImpl(credentials.Credentials):
Expand Down Expand Up @@ -236,24 +235,6 @@ def test_create_scoped_if_required_not_scopes():
assert scoped_credentials is unscoped_credentials


def test_get_background_refresh_error_no_error():
credentials = CredentialsImpl()
error = credentials.get_background_refresh_error()

assert error is None


def test_get_background_refresh_error():
credentials = CredentialsImpl()
credentials._refresh_worker._worker = mock.MagicMock
credentials._refresh_worker._worker._error_info = exceptions.RefreshError(
"sentinel"
)
error = credentials.get_background_refresh_error()

assert isinstance(error, exceptions.RefreshError)


def test_nonblocking_refresh_fresh_credentials():
c = CredentialsImpl()

Expand Down Expand Up @@ -315,6 +296,38 @@ def test_nonblocking_refresh_stale_credentials():
assert "x-identity-trust-boundary" not in headers


def test_nonblocking_refresh_failed_credentials():
c = CredentialsImpl()
c.with_non_blocking_refresh()

request = "token"
headers = {}

# Invalid credentials MUST require a blocking refresh.
c.before_request(request, "http://example.com", "GET", headers)
assert c.token_state == credentials.TokenState.FRESH
assert not c._refresh_worker._worker

c.expiry = (
datetime.datetime.utcnow()
+ _helpers.REFRESH_THRESHOLD
- datetime.timedelta(seconds=1)
)

# STALE credentials SHOULD spawn a non-blocking worker
assert c.token_state == credentials.TokenState.STALE
c._refresh_worker._worker = mock.MagicMock()
c._refresh_worker._worker._error_info = "Some Error"
c.before_request(request, "http://example.com", "GET", headers)
assert c._refresh_worker._worker is not None

assert c.token_state == credentials.TokenState.FRESH
assert c.valid
assert c.token == "token"
assert headers["authorization"] == "Bearer token"
assert "x-identity-trust-boundary" not in headers


def test_token_state_no_expiry():
c = CredentialsImpl()

Expand Down