Skip to content

Commit

Permalink
feat: adding universe domain support for downscroped credentials (#1463)
Browse files Browse the repository at this point in the history
* feat: adding universe domain support for downscroped credentials

* fix lint

* address comments

* Update tests/test_downscoped.py

Co-authored-by: Leo <39062083+lsirac@users.noreply.github.com>

---------

Co-authored-by: Leo <39062083+lsirac@users.noreply.github.com>
  • Loading branch information
BigTailWolf and lsirac committed Feb 8, 2024
1 parent 2c3304b commit fa8b7b2
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 8 deletions.
4 changes: 3 additions & 1 deletion google/auth/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
from google.auth import metrics
from google.auth._refresh_worker import RefreshThreadManager

DEFAULT_UNIVERSE_DOMAIN = "googleapis.com"


class Credentials(metaclass=abc.ABCMeta):
"""Base class for all credentials.
Expand Down Expand Up @@ -57,7 +59,7 @@ def __init__(self):
"""Optional[dict]: Cache of a trust boundary response which has a list
of allowed regions and an encoded string representation of credentials
trust boundary."""
self._universe_domain = "googleapis.com"
self._universe_domain = DEFAULT_UNIVERSE_DOMAIN
"""Optional[str]: The universe domain value, default is googleapis.com
"""

Expand Down
14 changes: 11 additions & 3 deletions google/auth/downscoped.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@
# The token exchange requested_token_type. This is always an access_token.
_STS_REQUESTED_TOKEN_TYPE = "urn:ietf:params:oauth:token-type:access_token"
# The STS token URL used to exchanged a short lived access token for a downscoped one.
_STS_TOKEN_URL = "https://sts.googleapis.com/v1/token"
_STS_TOKEN_URL_PATTERN = "https://sts.{}/v1/token"
# The subject token type to use when exchanging a short lived access token for a
# downscoped token.
_STS_SUBJECT_TOKEN_TYPE = "urn:ietf:params:oauth:token-type:access_token"
Expand Down Expand Up @@ -437,7 +437,11 @@ class Credentials(credentials.CredentialsWithQuotaProject):
"""

def __init__(
self, source_credentials, credential_access_boundary, quota_project_id=None
self,
source_credentials,
credential_access_boundary,
quota_project_id=None,
universe_domain=credentials.DEFAULT_UNIVERSE_DOMAIN,
):
"""Instantiates a downscoped credentials object using the provided source
credentials and credential access boundary rules.
Expand All @@ -456,6 +460,7 @@ def __init__(
the upper bound of the permissions that are available on that resource and an
optional condition to further restrict permissions.
quota_project_id (Optional[str]): The optional quota project ID.
universe_domain (Optional[str]): The universe domain value, default is googleapis.com
Raises:
google.auth.exceptions.RefreshError: If the source credentials
return an error on token refresh.
Expand All @@ -467,7 +472,10 @@ def __init__(
self._source_credentials = source_credentials
self._credential_access_boundary = credential_access_boundary
self._quota_project_id = quota_project_id
self._sts_client = sts.Client(_STS_TOKEN_URL)
self._universe_domain = universe_domain or credentials.DEFAULT_UNIVERSE_DOMAIN
self._sts_client = sts.Client(
_STS_TOKEN_URL_PATTERN.format(self.universe_domain)
)

@_helpers.copy_docstring(credentials.Credentials)
def refresh(self, request):
Expand Down
88 changes: 84 additions & 4 deletions tests/test_downscoped.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from google.auth import downscoped
from google.auth import exceptions
from google.auth import transport
from google.auth.credentials import DEFAULT_UNIVERSE_DOMAIN
from google.auth.credentials import TokenState


Expand Down Expand Up @@ -447,7 +448,11 @@ def test_to_json(self):

class TestCredentials(object):
@staticmethod
def make_credentials(source_credentials=SourceCredentials(), quota_project_id=None):
def make_credentials(
source_credentials=SourceCredentials(),
quota_project_id=None,
universe_domain=None,
):
availability_condition = make_availability_condition(
EXPRESSION, TITLE, DESCRIPTION
)
Expand All @@ -458,7 +463,10 @@ def make_credentials(source_credentials=SourceCredentials(), quota_project_id=No
credential_access_boundary = make_credential_access_boundary(rules)

return downscoped.Credentials(
source_credentials, credential_access_boundary, quota_project_id
source_credentials,
credential_access_boundary,
quota_project_id,
universe_domain,
)

@staticmethod
Expand All @@ -473,10 +481,12 @@ def make_mock_request(data, status=http_client.OK):
return request

@staticmethod
def assert_request_kwargs(request_kwargs, headers, request_data):
def assert_request_kwargs(
request_kwargs, headers, request_data, token_endpoint=TOKEN_EXCHANGE_ENDPOINT
):
"""Asserts the request was called with the expected parameters.
"""
assert request_kwargs["url"] == TOKEN_EXCHANGE_ENDPOINT
assert request_kwargs["url"] == token_endpoint
assert request_kwargs["method"] == "POST"
assert request_kwargs["headers"] == headers
assert request_kwargs["body"] is not None
Expand All @@ -496,6 +506,33 @@ def test_default_state(self):
assert not credentials.expired
# No quota project ID set.
assert not credentials.quota_project_id
assert credentials.universe_domain == DEFAULT_UNIVERSE_DOMAIN

def test_default_state_with_explicit_none_value(self):
credentials = self.make_credentials(universe_domain=None)

# No token acquired yet.
assert not credentials.token
assert not credentials.valid
# Expiration hasn't been set yet.
assert not credentials.expiry
assert not credentials.expired
# No quota project ID set.
assert not credentials.quota_project_id
assert credentials.universe_domain == DEFAULT_UNIVERSE_DOMAIN

def test_create_with_customized_universe_domain(self):
test_universe_domain = "foo.com"
credentials = self.make_credentials(universe_domain=test_universe_domain)
# No token acquired yet.
assert not credentials.token
assert not credentials.valid
# Expiration hasn't been set yet.
assert not credentials.expiry
assert not credentials.expired
# No quota project ID set.
assert not credentials.quota_project_id
assert credentials.universe_domain == test_universe_domain

def test_with_quota_project(self):
credentials = self.make_credentials()
Expand All @@ -506,6 +543,49 @@ def test_with_quota_project(self):

assert quota_project_creds.quota_project_id == "project-foo"

@mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min)
def test_refresh_on_custom_universe(self, unused_utcnow):
test_universe_domain = "foo.com"
response = SUCCESS_RESPONSE.copy()
# Test custom expiration to confirm expiry is set correctly.
response["expires_in"] = 2800
expected_expiry = datetime.datetime.min + datetime.timedelta(
seconds=response["expires_in"]
)
headers = {"Content-Type": "application/x-www-form-urlencoded"}
request_data = {
"grant_type": GRANT_TYPE,
"subject_token": "ACCESS_TOKEN_1",
"subject_token_type": SUBJECT_TOKEN_TYPE,
"requested_token_type": REQUESTED_TOKEN_TYPE,
"options": urllib.parse.quote(json.dumps(CREDENTIAL_ACCESS_BOUNDARY_JSON)),
}
request = self.make_mock_request(status=http_client.OK, data=response)
source_credentials = SourceCredentials()
credentials = self.make_credentials(
source_credentials=source_credentials, universe_domain=test_universe_domain
)
token_exchange_endpoint = downscoped._STS_TOKEN_URL_PATTERN.format(
test_universe_domain
)

# Spy on calls to source credentials refresh to confirm the expected request
# instance is used.
with mock.patch.object(
source_credentials, "refresh", wraps=source_credentials.refresh
) as wrapped_souce_cred_refresh:
credentials.refresh(request)

self.assert_request_kwargs(
request.call_args[1], headers, request_data, token_exchange_endpoint
)
assert credentials.valid
assert credentials.expiry == expected_expiry
assert not credentials.expired
assert credentials.token == response["access_token"]
# Confirm source credentials called with the same request instance.
wrapped_souce_cred_refresh.assert_called_with(request)

@mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min)
def test_refresh(self, unused_utcnow):
response = SUCCESS_RESPONSE.copy()
Expand Down

0 comments on commit fa8b7b2

Please sign in to comment.