Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: adding universe domain support for downscroped credentials #1463

Merged
merged 5 commits into from
Feb 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 3 additions & 1 deletion google/auth/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
from google.auth import metrics
from google.auth._refresh_worker import RefreshThreadManager

DEFAULT_UNIVERSE_DOMAIN = "googleapis.com"
BigTailWolf marked this conversation as resolved.
Show resolved Hide resolved


class Credentials(metaclass=abc.ABCMeta):
"""Base class for all credentials.
Expand Down Expand Up @@ -57,7 +59,7 @@ def __init__(self):
"""Optional[dict]: Cache of a trust boundary response which has a list
of allowed regions and an encoded string representation of credentials
trust boundary."""
self._universe_domain = "googleapis.com"
self._universe_domain = DEFAULT_UNIVERSE_DOMAIN
"""Optional[str]: The universe domain value, default is googleapis.com
"""

Expand Down
14 changes: 11 additions & 3 deletions google/auth/downscoped.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@
# The token exchange requested_token_type. This is always an access_token.
_STS_REQUESTED_TOKEN_TYPE = "urn:ietf:params:oauth:token-type:access_token"
# The STS token URL used to exchanged a short lived access token for a downscoped one.
_STS_TOKEN_URL = "https://sts.googleapis.com/v1/token"
_STS_TOKEN_URL_PATTERN = "https://sts.{}/v1/token"
# The subject token type to use when exchanging a short lived access token for a
# downscoped token.
_STS_SUBJECT_TOKEN_TYPE = "urn:ietf:params:oauth:token-type:access_token"
Expand Down Expand Up @@ -437,7 +437,11 @@ class Credentials(credentials.CredentialsWithQuotaProject):
"""

def __init__(
self, source_credentials, credential_access_boundary, quota_project_id=None
self,
source_credentials,
credential_access_boundary,
quota_project_id=None,
universe_domain=credentials.DEFAULT_UNIVERSE_DOMAIN,
BigTailWolf marked this conversation as resolved.
Show resolved Hide resolved
):
"""Instantiates a downscoped credentials object using the provided source
credentials and credential access boundary rules.
Expand All @@ -456,6 +460,7 @@ def __init__(
the upper bound of the permissions that are available on that resource and an
optional condition to further restrict permissions.
quota_project_id (Optional[str]): The optional quota project ID.
universe_domain (Optional[str]): The universe domain value, default is googleapis.com
Raises:
google.auth.exceptions.RefreshError: If the source credentials
return an error on token refresh.
Expand All @@ -467,7 +472,10 @@ def __init__(
self._source_credentials = source_credentials
self._credential_access_boundary = credential_access_boundary
self._quota_project_id = quota_project_id
self._sts_client = sts.Client(_STS_TOKEN_URL)
self._universe_domain = universe_domain or credentials.DEFAULT_UNIVERSE_DOMAIN
BigTailWolf marked this conversation as resolved.
Show resolved Hide resolved
self._sts_client = sts.Client(
_STS_TOKEN_URL_PATTERN.format(self.universe_domain)
)

@_helpers.copy_docstring(credentials.Credentials)
def refresh(self, request):
Expand Down
88 changes: 84 additions & 4 deletions tests/test_downscoped.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from google.auth import downscoped
from google.auth import exceptions
from google.auth import transport
from google.auth.credentials import DEFAULT_UNIVERSE_DOMAIN
from google.auth.credentials import TokenState


Expand Down Expand Up @@ -447,7 +448,11 @@ def test_to_json(self):

class TestCredentials(object):
@staticmethod
def make_credentials(source_credentials=SourceCredentials(), quota_project_id=None):
def make_credentials(
source_credentials=SourceCredentials(),
quota_project_id=None,
universe_domain=None,
):
availability_condition = make_availability_condition(
EXPRESSION, TITLE, DESCRIPTION
)
Expand All @@ -458,7 +463,10 @@ def make_credentials(source_credentials=SourceCredentials(), quota_project_id=No
credential_access_boundary = make_credential_access_boundary(rules)

return downscoped.Credentials(
source_credentials, credential_access_boundary, quota_project_id
source_credentials,
credential_access_boundary,
quota_project_id,
universe_domain,
)

@staticmethod
Expand All @@ -473,10 +481,12 @@ def make_mock_request(data, status=http_client.OK):
return request

@staticmethod
def assert_request_kwargs(request_kwargs, headers, request_data):
def assert_request_kwargs(
request_kwargs, headers, request_data, token_endpoint=TOKEN_EXCHANGE_ENDPOINT
):
"""Asserts the request was called with the expected parameters.
"""
assert request_kwargs["url"] == TOKEN_EXCHANGE_ENDPOINT
assert request_kwargs["url"] == token_endpoint
assert request_kwargs["method"] == "POST"
assert request_kwargs["headers"] == headers
assert request_kwargs["body"] is not None
Expand All @@ -496,6 +506,33 @@ def test_default_state(self):
assert not credentials.expired
# No quota project ID set.
assert not credentials.quota_project_id
assert credentials.universe_domain == DEFAULT_UNIVERSE_DOMAIN

def test_default_state_with_explicit_none_value(self):
credentials = self.make_credentials(universe_domain=None)

# No token acquired yet.
assert not credentials.token
assert not credentials.valid
# Expiration hasn't been set yet.
assert not credentials.expiry
assert not credentials.expired
# No quota project ID set.
assert not credentials.quota_project_id
assert credentials.universe_domain == DEFAULT_UNIVERSE_DOMAIN

def test_create_with_customized_universe_domain(self):
BigTailWolf marked this conversation as resolved.
Show resolved Hide resolved
test_universe_domain = "foo.com"
credentials = self.make_credentials(universe_domain=test_universe_domain)
# No token acquired yet.
assert not credentials.token
assert not credentials.valid
# Expiration hasn't been set yet.
assert not credentials.expiry
assert not credentials.expired
# No quota project ID set.
assert not credentials.quota_project_id
assert credentials.universe_domain == test_universe_domain

def test_with_quota_project(self):
credentials = self.make_credentials()
Expand All @@ -506,6 +543,49 @@ def test_with_quota_project(self):

assert quota_project_creds.quota_project_id == "project-foo"

@mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min)
def test_refresh_on_custom_universe(self, unused_utcnow):
test_universe_domain = "foo.com"
response = SUCCESS_RESPONSE.copy()
# Test custom expiration to confirm expiry is set correctly.
response["expires_in"] = 2800
expected_expiry = datetime.datetime.min + datetime.timedelta(
seconds=response["expires_in"]
)
headers = {"Content-Type": "application/x-www-form-urlencoded"}
request_data = {
"grant_type": GRANT_TYPE,
"subject_token": "ACCESS_TOKEN_1",
"subject_token_type": SUBJECT_TOKEN_TYPE,
"requested_token_type": REQUESTED_TOKEN_TYPE,
"options": urllib.parse.quote(json.dumps(CREDENTIAL_ACCESS_BOUNDARY_JSON)),
}
request = self.make_mock_request(status=http_client.OK, data=response)
source_credentials = SourceCredentials()
credentials = self.make_credentials(
source_credentials=source_credentials, universe_domain=test_universe_domain
)
token_exchange_endpoint = downscoped._STS_TOKEN_URL_PATTERN.format(
test_universe_domain
)

# Spy on calls to source credentials refresh to confirm the expected request
# instance is used.
with mock.patch.object(
source_credentials, "refresh", wraps=source_credentials.refresh
) as wrapped_souce_cred_refresh:
credentials.refresh(request)

self.assert_request_kwargs(
request.call_args[1], headers, request_data, token_exchange_endpoint
)
assert credentials.valid
assert credentials.expiry == expected_expiry
assert not credentials.expired
assert credentials.token == response["access_token"]
# Confirm source credentials called with the same request instance.
wrapped_souce_cred_refresh.assert_called_with(request)

@mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min)
def test_refresh(self, unused_utcnow):
response = SUCCESS_RESPONSE.copy()
Expand Down