Skip to content

Commit 4fcf83e

Browse files
lqiu96zhumin8
andauthoredAug 29, 2024··
feat: Support retrieving ID Token from IAM endpoint for ServiceAccountCredentials (#1433)
* feat: Support retrieving ID Token from IAM endpoint for ServiceAccountCredentials * chore: Update logic to use non-GDU IAM Endpoint URL * chore: Reformat the logic * chore: Refactor to use MockIAMCredentialsServiceTransportFactory * chore: Update tests * chore: Update tests * Update oauth2_http/java/com/google/auth/oauth2/ServiceAccountCredentials.java Co-authored-by: Min Zhu <zhumin@google.com> * chore: Fix lint issues * chore: Fix lint issues * chore: Address PR comments * chore: Add comment for initializing iamIdTokenURI in method --------- Co-authored-by: Min Zhu <zhumin@google.com>
1 parent 2418691 commit 4fcf83e

8 files changed

+455
-238
lines changed
 

‎oauth2_http/java/com/google/auth/oauth2/OAuth2Utils.java

+4
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,10 @@ class OAuth2Utils {
7575
static final String TOKEN_TYPE_TOKEN_EXCHANGE = "urn:ietf:params:oauth:token-type:token-exchange";
7676
static final String GRANT_TYPE_JWT_BEARER = "urn:ietf:params:oauth:grant-type:jwt-bearer";
7777

78+
// generateIdToken endpoint is to be formatted with universe domain and client email
79+
static final String IAM_ID_TOKEN_ENDPOINT_FORMAT =
80+
"https://iamcredentials.%s/v1/projects/-/serviceAccounts/%s:generateIdToken";
81+
7882
static final URI TOKEN_SERVER_URI = URI.create("https://oauth2.googleapis.com/token");
7983
static final URI TOKEN_REVOKE_URI = URI.create("https://oauth2.googleapis.com/revoke");
8084
static final URI USER_AUTH_URI = URI.create("https://accounts.google.com/o/oauth2/auth");

‎oauth2_http/java/com/google/auth/oauth2/ServiceAccountCredentials.java

+93-23
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@
3636
import com.google.api.client.http.GenericUrl;
3737
import com.google.api.client.http.HttpBackOffIOExceptionHandler;
3838
import com.google.api.client.http.HttpBackOffUnsuccessfulResponseHandler;
39+
import com.google.api.client.http.HttpContent;
40+
import com.google.api.client.http.HttpHeaders;
3941
import com.google.api.client.http.HttpRequest;
4042
import com.google.api.client.http.HttpRequestFactory;
4143
import com.google.api.client.http.HttpResponse;
@@ -52,9 +54,12 @@
5254
import com.google.auth.Credentials;
5355
import com.google.auth.RequestMetadataCallback;
5456
import com.google.auth.ServiceAccountSigner;
57+
import com.google.auth.http.AuthHttpConstants;
5558
import com.google.auth.http.HttpTransportFactory;
5659
import com.google.common.annotations.VisibleForTesting;
5760
import com.google.common.base.MoreObjects.ToStringHelper;
61+
import com.google.common.collect.ImmutableList;
62+
import com.google.common.collect.ImmutableMap;
5863
import com.google.common.collect.ImmutableSet;
5964
import com.google.errorprone.annotations.CanIgnoreReturnValue;
6065
import java.io.IOException;
@@ -547,7 +552,9 @@ public AccessToken refreshAccessToken() throws IOException {
547552
}
548553

549554
/**
550-
* Returns a Google ID Token from the metadata server on ComputeEngine.
555+
* Returns a Google ID Token from either the Oauth or IAM Endpoint. For Credentials that are in
556+
* the Google Default Universe (googleapis.com), the ID Token will be retrieved from the Oauth
557+
* Endpoint. Otherwise, it will be retrieved from the IAM Endpoint.
551558
*
552559
* @param targetAudience the aud: field the IdToken should include.
553560
* @param options list of Credential specific options for the token. Currently, unused for
@@ -558,21 +565,90 @@ public AccessToken refreshAccessToken() throws IOException {
558565
@Override
559566
public IdToken idTokenWithAudience(String targetAudience, List<Option> options)
560567
throws IOException {
568+
return isDefaultUniverseDomain()
569+
? getIdTokenOauthEndpoint(targetAudience)
570+
: getIdTokenIamEndpoint(targetAudience);
571+
}
561572

562-
JsonFactory jsonFactory = OAuth2Utils.JSON_FACTORY;
573+
/**
574+
* Uses the Oauth Endpoint to generate an ID token. Assertions and grant_type are sent in the
575+
* request body.
576+
*/
577+
private IdToken getIdTokenOauthEndpoint(String targetAudience) throws IOException {
563578
long currentTime = clock.currentTimeMillis();
564579
String assertion =
565-
createAssertionForIdToken(
566-
jsonFactory, currentTime, tokenServerUri.toString(), targetAudience);
580+
createAssertionForIdToken(currentTime, tokenServerUri.toString(), targetAudience);
567581

582+
Map<String, Object> requestParams =
583+
ImmutableMap.of("grant_type", GRANT_TYPE, "assertion", assertion);
568584
GenericData tokenRequest = new GenericData();
569-
tokenRequest.set("grant_type", GRANT_TYPE);
570-
tokenRequest.set("assertion", assertion);
585+
requestParams.forEach(tokenRequest::set);
586+
UrlEncodedContent content = new UrlEncodedContent(tokenRequest);
587+
588+
HttpRequest request = buildIdTokenRequest(tokenServerUri, transportFactory, content);
589+
HttpResponse httpResponse = executeRequest(request);
590+
591+
GenericData responseData = httpResponse.parseAs(GenericData.class);
592+
String rawToken = OAuth2Utils.validateString(responseData, "id_token", PARSE_ERROR_PREFIX);
593+
return IdToken.create(rawToken);
594+
}
595+
596+
/**
597+
* Use IAM generateIdToken endpoint to obtain an ID token.
598+
*
599+
* <p>This flow works as follows:
600+
*
601+
* <ol>
602+
* <li>Create a self-signed jwt with `https://www.googleapis.com/auth/iam` as the scope.
603+
* <li>Use the self-signed jwt as the access token, and make a POST request to IAM
604+
* generateIdToken endpoint.
605+
* <li>If the request is successfully, it will return {"token":"the ID token"}. Extract the ID
606+
* token.
607+
* </ol>
608+
*/
609+
private IdToken getIdTokenIamEndpoint(String targetAudience) throws IOException {
610+
JwtCredentials selfSignedJwtCredentials =
611+
createSelfSignedJwtCredentials(
612+
null, ImmutableList.of("https://www.googleapis.com/auth/iam"));
613+
Map<String, List<String>> responseMetadata = selfSignedJwtCredentials.getRequestMetadata(null);
614+
// JwtCredentials will return a map with one entry ("Authorization" -> List with size 1)
615+
String accessToken = responseMetadata.get(AuthHttpConstants.AUTHORIZATION).get(0);
616+
617+
// Do not check user options. These params are always set regardless of options configured
618+
Map<String, Object> requestParams =
619+
ImmutableMap.of("audience", targetAudience, "includeEmail", "true", "useEmailAzp", "true");
620+
GenericData tokenRequest = new GenericData();
621+
requestParams.forEach(tokenRequest::set);
571622
UrlEncodedContent content = new UrlEncodedContent(tokenRequest);
572623

624+
// Create IAM Token URI in this method instead of in the constructor because
625+
// `getUniverseDomain()` throws an IOException that would need to be caught
626+
URI iamIdTokenUri =
627+
URI.create(
628+
String.format(
629+
OAuth2Utils.IAM_ID_TOKEN_ENDPOINT_FORMAT, getUniverseDomain(), clientEmail));
630+
HttpRequest request = buildIdTokenRequest(iamIdTokenUri, transportFactory, content);
631+
// Use the Access Token from the SSJWT to request the ID Token from IAM Endpoint
632+
request.setHeaders(new HttpHeaders().set(AuthHttpConstants.AUTHORIZATION, accessToken));
633+
HttpResponse httpResponse = executeRequest(request);
634+
635+
GenericData responseData = httpResponse.parseAs(GenericData.class);
636+
// IAM Endpoint returns `token` instead of `id_token`
637+
String rawToken = OAuth2Utils.validateString(responseData, "token", PARSE_ERROR_PREFIX);
638+
return IdToken.create(rawToken);
639+
}
640+
641+
// Build a default POST HttpRequest to be used for both Oauth and IAM endpoints
642+
private HttpRequest buildIdTokenRequest(
643+
URI uri, HttpTransportFactory transportFactory, HttpContent content) throws IOException {
644+
JsonFactory jsonFactory = OAuth2Utils.JSON_FACTORY;
573645
HttpRequestFactory requestFactory = transportFactory.create().createRequestFactory();
574-
HttpRequest request = requestFactory.buildPostRequest(new GenericUrl(tokenServerUri), content);
646+
HttpRequest request = requestFactory.buildPostRequest(new GenericUrl(uri), content);
575647
request.setParser(new JsonObjectParser(jsonFactory));
648+
return request;
649+
}
650+
651+
private HttpResponse executeRequest(HttpRequest request) throws IOException {
576652
HttpResponse response;
577653
try {
578654
response = request.execute();
@@ -583,11 +659,7 @@ public IdToken idTokenWithAudience(String targetAudience, List<Option> options)
583659
e.getMessage(), getIssuer()),
584660
e);
585661
}
586-
587-
GenericData responseData = response.parseAs(GenericData.class);
588-
String rawToken = OAuth2Utils.validateString(responseData, "id_token", PARSE_ERROR_PREFIX);
589-
590-
return IdToken.create(rawToken);
662+
return response;
591663
}
592664

593665
/**
@@ -826,9 +898,9 @@ String createAssertion(JsonFactory jsonFactory, long currentTime) throws IOExcep
826898
}
827899

828900
@VisibleForTesting
829-
String createAssertionForIdToken(
830-
JsonFactory jsonFactory, long currentTime, String audience, String targetAudience)
901+
String createAssertionForIdToken(long currentTime, String audience, String targetAudience)
831902
throws IOException {
903+
JsonFactory jsonFactory = OAuth2Utils.JSON_FACTORY;
832904
JsonWebSignature.Header header = new JsonWebSignature.Header();
833905
header.setAlgorithm("RS256");
834906
header.setType("JWT");
@@ -849,9 +921,7 @@ String createAssertionForIdToken(
849921
try {
850922
payload.set("target_audience", targetAudience);
851923

852-
String assertion =
853-
JsonWebSignature.signUsingRsaSha256(privateKey, jsonFactory, header, payload);
854-
return assertion;
924+
return JsonWebSignature.signUsingRsaSha256(privateKey, jsonFactory, header, payload);
855925
} catch (GeneralSecurityException e) {
856926
throw new IOException(
857927
"Error signing service account access token request with private key.", e);
@@ -877,18 +947,18 @@ static URI getUriForSelfSignedJWT(URI uri) {
877947

878948
@VisibleForTesting
879949
JwtCredentials createSelfSignedJwtCredentials(final URI uri) {
950+
return createSelfSignedJwtCredentials(uri, scopes.isEmpty() ? defaultScopes : scopes);
951+
}
952+
953+
@VisibleForTesting
954+
JwtCredentials createSelfSignedJwtCredentials(final URI uri, Collection<String> scopes) {
880955
// Create a JwtCredentials for self-signed JWT. See https://google.aip.dev/auth/4111.
881956
JwtClaims.Builder claimsBuilder =
882957
JwtClaims.newBuilder().setIssuer(clientEmail).setSubject(clientEmail);
883958

884959
if (uri == null) {
885960
// If uri is null, use scopes.
886-
String scopeClaim = "";
887-
if (!scopes.isEmpty()) {
888-
scopeClaim = Joiner.on(' ').join(scopes);
889-
} else {
890-
scopeClaim = Joiner.on(' ').join(defaultScopes);
891-
}
961+
String scopeClaim = Joiner.on(' ').join(scopes);
892962
claimsBuilder.setAdditionalClaims(Collections.singletonMap("scope", scopeClaim));
893963
} else {
894964
// otherwise, use audience with the uri.

‎oauth2_http/javatests/com/google/auth/oauth2/GoogleCredentialsTest.java

+22-15
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@
4141
import com.google.auth.http.HttpTransportFactory;
4242
import com.google.auth.oauth2.ExternalAccountAuthorizedUserCredentialsTest.MockExternalAccountAuthorizedUserCredentialsTransportFactory;
4343
import com.google.auth.oauth2.IdentityPoolCredentialsTest.MockExternalAccountCredentialsTransportFactory;
44-
import com.google.auth.oauth2.ImpersonatedCredentialsTest.MockIAMCredentialsServiceTransportFactory;
4544
import com.google.common.collect.ImmutableList;
4645
import java.io.ByteArrayInputStream;
4746
import java.io.IOException;
@@ -604,13 +603,17 @@ public void fromStream_Impersonation_providesToken_WithQuotaProject() throws IOE
604603

605604
MockIAMCredentialsServiceTransportFactory transportFactory =
606605
new MockIAMCredentialsServiceTransportFactory();
607-
transportFactory.transport.setTargetPrincipal(
608-
ImpersonatedCredentialsTest.IMPERSONATED_CLIENT_EMAIL);
609-
transportFactory.transport.setAccessToken(ImpersonatedCredentialsTest.ACCESS_TOKEN);
610-
transportFactory.transport.setExpireTime(ImpersonatedCredentialsTest.getDefaultExpireTime());
611-
transportFactory.transport.setAccessTokenEndpoint(
612-
ImpersonatedCredentialsTest.IMPERSONATION_URL);
613-
transportFactory.transport.addStatusCodeAndMessage(HttpStatusCodes.STATUS_CODE_OK, "");
606+
transportFactory
607+
.getTransport()
608+
.setTargetPrincipal(ImpersonatedCredentialsTest.IMPERSONATED_CLIENT_EMAIL);
609+
transportFactory.getTransport().setAccessToken(ImpersonatedCredentialsTest.ACCESS_TOKEN);
610+
transportFactory
611+
.getTransport()
612+
.setExpireTime(ImpersonatedCredentialsTest.getDefaultExpireTime());
613+
transportFactory
614+
.getTransport()
615+
.setAccessTokenEndpoint(ImpersonatedCredentialsTest.IMPERSONATION_URL);
616+
transportFactory.getTransport().addStatusCodeAndMessage(HttpStatusCodes.STATUS_CODE_OK, "");
614617

615618
InputStream impersonationCredentialsStream =
616619
ImpersonatedCredentialsTest.writeImpersonationCredentialsStream(
@@ -665,13 +668,17 @@ public void fromStream_Impersonation_providesToken_WithoutQuotaProject() throws
665668

666669
MockIAMCredentialsServiceTransportFactory transportFactory =
667670
new MockIAMCredentialsServiceTransportFactory();
668-
transportFactory.transport.setTargetPrincipal(
669-
ImpersonatedCredentialsTest.IMPERSONATED_CLIENT_EMAIL);
670-
transportFactory.transport.setAccessToken(ImpersonatedCredentialsTest.ACCESS_TOKEN);
671-
transportFactory.transport.setExpireTime(ImpersonatedCredentialsTest.getDefaultExpireTime());
672-
transportFactory.transport.setAccessTokenEndpoint(
673-
ImpersonatedCredentialsTest.IMPERSONATION_URL);
674-
transportFactory.transport.addStatusCodeAndMessage(HttpStatusCodes.STATUS_CODE_OK, "");
671+
transportFactory
672+
.getTransport()
673+
.setTargetPrincipal(ImpersonatedCredentialsTest.IMPERSONATED_CLIENT_EMAIL);
674+
transportFactory.getTransport().setAccessToken(ImpersonatedCredentialsTest.ACCESS_TOKEN);
675+
transportFactory
676+
.getTransport()
677+
.setExpireTime(ImpersonatedCredentialsTest.getDefaultExpireTime());
678+
transportFactory
679+
.getTransport()
680+
.setAccessTokenEndpoint(ImpersonatedCredentialsTest.IMPERSONATION_URL);
681+
transportFactory.getTransport().addStatusCodeAndMessage(HttpStatusCodes.STATUS_CODE_OK, "");
675682

676683
InputStream impersonationCredentialsStream =
677684
ImpersonatedCredentialsTest.writeImpersonationCredentialsStream(

‎oauth2_http/javatests/com/google/auth/oauth2/IamUtilsTest.java

+65-54
Original file line numberDiff line numberDiff line change
@@ -66,22 +66,22 @@ public void setup() throws IOException {
6666
public void sign_success_noRetry() {
6767
byte[] expectedSignature = {0xD, 0xE, 0xA, 0xD};
6868

69-
ImpersonatedCredentialsTest.MockIAMCredentialsServiceTransportFactory transportFactory =
70-
new ImpersonatedCredentialsTest.MockIAMCredentialsServiceTransportFactory();
71-
transportFactory.transport.setSignedBlob(expectedSignature);
72-
transportFactory.transport.setTargetPrincipal(CLIENT_EMAIL);
73-
transportFactory.transport.addStatusCodeAndMessage(HttpStatusCodes.STATUS_CODE_OK, "");
69+
MockIAMCredentialsServiceTransportFactory transportFactory =
70+
new MockIAMCredentialsServiceTransportFactory();
71+
transportFactory.getTransport().setSignedBlob(expectedSignature);
72+
transportFactory.getTransport().setTargetPrincipal(CLIENT_EMAIL);
73+
transportFactory.getTransport().addStatusCodeAndMessage(HttpStatusCodes.STATUS_CODE_OK, "");
7474

7575
byte[] signature =
7676
IamUtils.sign(
7777
CLIENT_EMAIL,
7878
credentials,
79-
transportFactory.transport,
79+
transportFactory.getTransport(),
8080
expectedSignature,
8181
ImmutableMap.of());
8282
assertArrayEquals(expectedSignature, signature);
8383

84-
assertEquals(1, transportFactory.transport.getNumRequests());
84+
assertEquals(1, transportFactory.getTransport().getNumRequests());
8585
}
8686

8787
// The SignBlob RPC will retry up to three times before it gives up. This test will return two
@@ -91,28 +91,30 @@ public void sign_success_noRetry() {
9191
public void sign_retryTwoTimes_success() {
9292
byte[] expectedSignature = {0xD, 0xE, 0xA, 0xD};
9393

94-
ImpersonatedCredentialsTest.MockIAMCredentialsServiceTransportFactory transportFactory =
95-
new ImpersonatedCredentialsTest.MockIAMCredentialsServiceTransportFactory();
96-
transportFactory.transport.addStatusCodeAndMessage(
97-
HttpStatusCodes.STATUS_CODE_BAD_GATEWAY, "Bad Gateway");
98-
transportFactory.transport.addStatusCodeAndMessage(
99-
HttpStatusCodes.STATUS_CODE_SERVICE_UNAVAILABLE, "Unavailable");
100-
transportFactory.transport.addStatusCodeAndMessage(HttpStatusCodes.STATUS_CODE_OK, "");
101-
transportFactory.transport.setSignedBlob(expectedSignature);
102-
transportFactory.transport.setTargetPrincipal(CLIENT_EMAIL);
94+
MockIAMCredentialsServiceTransportFactory transportFactory =
95+
new MockIAMCredentialsServiceTransportFactory();
96+
transportFactory
97+
.getTransport()
98+
.addStatusCodeAndMessage(HttpStatusCodes.STATUS_CODE_BAD_GATEWAY, "Bad Gateway");
99+
transportFactory
100+
.getTransport()
101+
.addStatusCodeAndMessage(HttpStatusCodes.STATUS_CODE_SERVICE_UNAVAILABLE, "Unavailable");
102+
transportFactory.getTransport().addStatusCodeAndMessage(HttpStatusCodes.STATUS_CODE_OK, "");
103+
transportFactory.getTransport().setSignedBlob(expectedSignature);
104+
transportFactory.getTransport().setTargetPrincipal(CLIENT_EMAIL);
103105

104106
byte[] signature =
105107
IamUtils.sign(
106108
CLIENT_EMAIL,
107109
credentials,
108-
transportFactory.transport,
110+
transportFactory.getTransport(),
109111
expectedSignature,
110112
ImmutableMap.of());
111113
assertArrayEquals(expectedSignature, signature);
112114

113115
// Expect that three requests are made (2 failures which are retries + 1 final requests which
114116
// resulted in a successful response)
115-
assertEquals(3, transportFactory.transport.getNumRequests());
117+
assertEquals(3, transportFactory.getTransport().getNumRequests());
116118
}
117119

118120
// The rpc will retry up to three times before it gives up. This test will enqueue three failed
@@ -122,30 +124,33 @@ public void sign_retryTwoTimes_success() {
122124
public void sign_retryThreeTimes_success() {
123125
byte[] expectedSignature = {0xD, 0xE, 0xA, 0xD};
124126

125-
ImpersonatedCredentialsTest.MockIAMCredentialsServiceTransportFactory transportFactory =
126-
new ImpersonatedCredentialsTest.MockIAMCredentialsServiceTransportFactory();
127-
transportFactory.transport.setSignedBlob(expectedSignature);
128-
transportFactory.transport.setTargetPrincipal(CLIENT_EMAIL);
129-
transportFactory.transport.addStatusCodeAndMessage(
130-
HttpStatusCodes.STATUS_CODE_BAD_GATEWAY, "Bad Gateway");
131-
transportFactory.transport.addStatusCodeAndMessage(
132-
HttpStatusCodes.STATUS_CODE_SERVICE_UNAVAILABLE, "Unavailable");
133-
transportFactory.transport.addStatusCodeAndMessage(
134-
HttpStatusCodes.STATUS_CODE_SERVER_ERROR, "Server Error");
135-
transportFactory.transport.addStatusCodeAndMessage(HttpStatusCodes.STATUS_CODE_OK, "");
127+
MockIAMCredentialsServiceTransportFactory transportFactory =
128+
new MockIAMCredentialsServiceTransportFactory();
129+
transportFactory.getTransport().setSignedBlob(expectedSignature);
130+
transportFactory.getTransport().setTargetPrincipal(CLIENT_EMAIL);
131+
transportFactory
132+
.getTransport()
133+
.addStatusCodeAndMessage(HttpStatusCodes.STATUS_CODE_BAD_GATEWAY, "Bad Gateway");
134+
transportFactory
135+
.getTransport()
136+
.addStatusCodeAndMessage(HttpStatusCodes.STATUS_CODE_SERVICE_UNAVAILABLE, "Unavailable");
137+
transportFactory
138+
.getTransport()
139+
.addStatusCodeAndMessage(HttpStatusCodes.STATUS_CODE_SERVER_ERROR, "Server Error");
140+
transportFactory.getTransport().addStatusCodeAndMessage(HttpStatusCodes.STATUS_CODE_OK, "");
136141

137142
byte[] signature =
138143
IamUtils.sign(
139144
CLIENT_EMAIL,
140145
credentials,
141-
transportFactory.transport,
146+
transportFactory.getTransport(),
142147
expectedSignature,
143148
ImmutableMap.of());
144149
assertArrayEquals(expectedSignature, signature);
145150

146151
// Expect that three requests are made (3 failures which are retried + 1 final request which
147152
// resulted the final success response)
148-
assertEquals(4, transportFactory.transport.getNumRequests());
153+
assertEquals(4, transportFactory.getTransport().getNumRequests());
149154
}
150155

151156
// The rpc will retry up to three times before it gives up. This test will enqueue four failed
@@ -155,19 +160,23 @@ public void sign_retryThreeTimes_success() {
155160
public void sign_retryThreeTimes_exception() {
156161
byte[] expectedSignature = {0xD, 0xE, 0xA, 0xD};
157162

158-
ImpersonatedCredentialsTest.MockIAMCredentialsServiceTransportFactory transportFactory =
159-
new ImpersonatedCredentialsTest.MockIAMCredentialsServiceTransportFactory();
160-
transportFactory.transport.setSignedBlob(expectedSignature);
161-
transportFactory.transport.setTargetPrincipal(CLIENT_EMAIL);
162-
transportFactory.transport.addStatusCodeAndMessage(
163-
HttpStatusCodes.STATUS_CODE_BAD_GATEWAY, "Bad Gateway");
164-
transportFactory.transport.addStatusCodeAndMessage(
165-
HttpStatusCodes.STATUS_CODE_SERVICE_UNAVAILABLE, "Unavailable");
166-
transportFactory.transport.addStatusCodeAndMessage(
167-
HttpStatusCodes.STATUS_CODE_SERVER_ERROR, "Server Error");
168-
transportFactory.transport.addStatusCodeAndMessage(
169-
HttpStatusCodes.STATUS_CODE_BAD_GATEWAY, "Bad Gateway");
170-
transportFactory.transport.addStatusCodeAndMessage(HttpStatusCodes.STATUS_CODE_OK, "");
163+
MockIAMCredentialsServiceTransportFactory transportFactory =
164+
new MockIAMCredentialsServiceTransportFactory();
165+
transportFactory.getTransport().setSignedBlob(expectedSignature);
166+
transportFactory.getTransport().setTargetPrincipal(CLIENT_EMAIL);
167+
transportFactory
168+
.getTransport()
169+
.addStatusCodeAndMessage(HttpStatusCodes.STATUS_CODE_BAD_GATEWAY, "Bad Gateway");
170+
transportFactory
171+
.getTransport()
172+
.addStatusCodeAndMessage(HttpStatusCodes.STATUS_CODE_SERVICE_UNAVAILABLE, "Unavailable");
173+
transportFactory
174+
.getTransport()
175+
.addStatusCodeAndMessage(HttpStatusCodes.STATUS_CODE_SERVER_ERROR, "Server Error");
176+
transportFactory
177+
.getTransport()
178+
.addStatusCodeAndMessage(HttpStatusCodes.STATUS_CODE_BAD_GATEWAY, "Bad Gateway");
179+
transportFactory.getTransport().addStatusCodeAndMessage(HttpStatusCodes.STATUS_CODE_OK, "");
171180

172181
ServiceAccountSigner.SigningException exception =
173182
assertThrows(
@@ -176,7 +185,7 @@ public void sign_retryThreeTimes_exception() {
176185
IamUtils.sign(
177186
CLIENT_EMAIL,
178187
credentials,
179-
transportFactory.transport,
188+
transportFactory.getTransport(),
180189
expectedSignature,
181190
ImmutableMap.of()));
182191
assertTrue(exception.getMessage().contains("Failed to sign the provided bytes"));
@@ -188,19 +197,21 @@ public void sign_retryThreeTimes_exception() {
188197

189198
// Expect that three requests are made (3 failures which are retried + 1 final request which
190199
// resulted in another failed response)
191-
assertEquals(4, transportFactory.transport.getNumRequests());
200+
assertEquals(4, transportFactory.getTransport().getNumRequests());
192201
}
193202

194203
@Test
195204
public void sign_4xxError_noRetry_exception() {
196205
byte[] expectedSignature = {0xD, 0xE, 0xA, 0xD};
197206

198-
ImpersonatedCredentialsTest.MockIAMCredentialsServiceTransportFactory transportFactory =
199-
new ImpersonatedCredentialsTest.MockIAMCredentialsServiceTransportFactory();
200-
transportFactory.transport.setSignedBlob(expectedSignature);
201-
transportFactory.transport.setTargetPrincipal(CLIENT_EMAIL);
202-
transportFactory.transport.addStatusCodeAndMessage(
203-
HttpStatusCodes.STATUS_CODE_UNAUTHORIZED, "Failed to sign the provided bytes");
207+
MockIAMCredentialsServiceTransportFactory transportFactory =
208+
new MockIAMCredentialsServiceTransportFactory();
209+
transportFactory.getTransport().setSignedBlob(expectedSignature);
210+
transportFactory.getTransport().setTargetPrincipal(CLIENT_EMAIL);
211+
transportFactory
212+
.getTransport()
213+
.addStatusCodeAndMessage(
214+
HttpStatusCodes.STATUS_CODE_UNAUTHORIZED, "Failed to sign the provided bytes");
204215

205216
ServiceAccountSigner.SigningException exception =
206217
assertThrows(
@@ -209,7 +220,7 @@ public void sign_4xxError_noRetry_exception() {
209220
IamUtils.sign(
210221
CLIENT_EMAIL,
211222
credentials,
212-
transportFactory.transport,
223+
transportFactory.getTransport(),
213224
expectedSignature,
214225
ImmutableMap.of()));
215226
assertTrue(exception.getMessage().contains("Failed to sign the provided bytes"));
@@ -220,6 +231,6 @@ public void sign_4xxError_noRetry_exception() {
220231
.contains("Error code 401 trying to sign provided bytes:"));
221232

222233
// Only one request will have been made for a 4xx error (no retries)
223-
assertEquals(1, transportFactory.transport.getNumRequests());
234+
assertEquals(1, transportFactory.getTransport().getNumRequests());
224235
}
225236
}

‎oauth2_http/javatests/com/google/auth/oauth2/ImpersonatedCredentialsTest.java

+120-123
Large diffs are not rendered by default.

‎oauth2_http/javatests/com/google/auth/oauth2/MockIAMCredentialsServiceTransport.java

+17-7
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@
3131

3232
package com.google.auth.oauth2;
3333

34+
import static com.google.auth.oauth2.OAuth2Utils.IAM_ID_TOKEN_ENDPOINT_FORMAT;
35+
3436
import com.google.api.client.http.HttpStatusCodes;
3537
import com.google.api.client.http.LowLevelHttpRequest;
3638
import com.google.api.client.http.LowLevelHttpResponse;
@@ -61,11 +63,9 @@ public ServerResponse(int statusCode, String response, boolean repeatServerRespo
6163
}
6264

6365
private static final String DEFAULT_IAM_ACCESS_TOKEN_ENDPOINT =
64-
"https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/%s:generateAccessToken";
65-
private static final String IAM_ID_TOKEN_ENDPOINT =
66-
"https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/%s:generateIdToken";
66+
"https://iamcredentials.%s/v1/projects/-/serviceAccounts/%s:generateAccessToken";
6767
private static final String IAM_SIGN_ENDPOINT =
68-
"https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/%s:signBlob";
68+
"https://iamcredentials.%s/v1/projects/-/serviceAccounts/%s:signBlob";
6969

7070
private final Deque<ServerResponse> serverResponses;
7171

@@ -78,8 +78,15 @@ public ServerResponse(int statusCode, String response, boolean repeatServerRespo
7878

7979
private String idToken;
8080

81+
private String universeDomain;
82+
8183
private MockLowLevelHttpRequest request;
8284

85+
MockIAMCredentialsServiceTransport(String universeDomain) {
86+
this.universeDomain = universeDomain;
87+
this.serverResponses = new ArrayDeque<>();
88+
}
89+
8390
// Store the number of requests that are sent to the Mock Server. This is used to track the
8491
// number of retries attempts made to ensure that retry boundaries are respected.
8592
private int numRequests;
@@ -138,9 +145,12 @@ public LowLevelHttpRequest buildRequest(String method, String url) throws IOExce
138145
String iamAccessTokenFormattedUrl =
139146
iamAccessTokenEndpoint != null
140147
? iamAccessTokenEndpoint
141-
: String.format(DEFAULT_IAM_ACCESS_TOKEN_ENDPOINT, this.targetPrincipal);
142-
String iamSignBlobformattedUrl = String.format(IAM_SIGN_ENDPOINT, this.targetPrincipal);
143-
String iamIdTokenformattedUrl = String.format(IAM_ID_TOKEN_ENDPOINT, this.targetPrincipal);
148+
: String.format(
149+
DEFAULT_IAM_ACCESS_TOKEN_ENDPOINT, universeDomain, this.targetPrincipal);
150+
String iamSignBlobformattedUrl =
151+
String.format(IAM_SIGN_ENDPOINT, universeDomain, this.targetPrincipal);
152+
String iamIdTokenformattedUrl =
153+
String.format(IAM_ID_TOKEN_ENDPOINT_FORMAT, universeDomain, this.targetPrincipal);
144154
ServerResponse serverResponse = serverResponses.poll();
145155
// Status code was configured to be repeated until connection is terminated
146156
if (serverResponse.repeatServerResponse) {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
/*
2+
* Copyright 2024, Google Inc. All rights reserved.
3+
*
4+
* Redistribution and use in source and binary forms, with or without
5+
* modification, are permitted provided that the following conditions are
6+
* met:
7+
*
8+
* * Redistributions of source code must retain the above copyright
9+
* notice, this list of conditions and the following disclaimer.
10+
* * Redistributions in binary form must reproduce the above
11+
* copyright notice, this list of conditions and the following disclaimer
12+
* in the documentation and/or other materials provided with the
13+
* distribution.
14+
*
15+
* * Neither the name of Google Inc. nor the names of its
16+
* contributors may be used to endorse or promote products derived from
17+
* this software without specific prior written permission.
18+
*
19+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20+
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21+
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22+
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23+
* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24+
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25+
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26+
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27+
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28+
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30+
*/
31+
package com.google.auth.oauth2;
32+
33+
import com.google.api.client.http.HttpTransport;
34+
import com.google.auth.Credentials;
35+
import com.google.auth.http.HttpTransportFactory;
36+
37+
public class MockIAMCredentialsServiceTransportFactory implements HttpTransportFactory {
38+
private MockIAMCredentialsServiceTransport transport;
39+
40+
MockIAMCredentialsServiceTransportFactory() {
41+
this(Credentials.GOOGLE_DEFAULT_UNIVERSE);
42+
}
43+
44+
MockIAMCredentialsServiceTransportFactory(String universeDomain) {
45+
this.transport = new MockIAMCredentialsServiceTransport(universeDomain);
46+
}
47+
48+
public MockIAMCredentialsServiceTransport getTransport() {
49+
return transport;
50+
}
51+
52+
@Override
53+
public HttpTransport create() {
54+
return transport;
55+
}
56+
}

‎oauth2_http/javatests/com/google/auth/oauth2/ServiceAccountCredentialsTest.java

+78-16
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
import static org.junit.Assert.fail;
4343

4444
import com.google.api.client.http.HttpResponseException;
45+
import com.google.api.client.http.HttpStatusCodes;
4546
import com.google.api.client.json.GenericJson;
4647
import com.google.api.client.json.JsonFactory;
4748
import com.google.api.client.json.gson.GsonFactory;
@@ -123,7 +124,6 @@ public class ServiceAccountCredentialsTest extends BaseSerializationTest {
123124
private static final int DEFAULT_LIFETIME_IN_SECONDS = 3600;
124125
private static final int INVALID_LIFETIME = 43210;
125126
private static final String JWT_ACCESS_PREFIX = "Bearer ";
126-
private static final String GOOGLE_DEFAULT_UNIVERSE = "googleapis.com";
127127

128128
private ServiceAccountCredentials.Builder createDefaultBuilderWithToken(String accessToken)
129129
throws IOException {
@@ -210,7 +210,7 @@ public void createdScoped_clones() throws IOException {
210210
assertArrayEquals(newScopes.toArray(), newCredentials.getScopes().toArray());
211211
assertEquals(USER, newCredentials.getServiceAccountUser());
212212
assertEquals(PROJECT_ID, newCredentials.getProjectId());
213-
assertEquals(GOOGLE_DEFAULT_UNIVERSE, newCredentials.getUniverseDomain());
213+
assertEquals(Credentials.GOOGLE_DEFAULT_UNIVERSE, newCredentials.getUniverseDomain());
214214

215215
assertArrayEquals(
216216
SCOPES.toArray(), ((ServiceAccountCredentials) credentials).getScopes().toArray());
@@ -310,8 +310,7 @@ public void createAssertionForIdToken_correct() throws IOException {
310310
JsonFactory jsonFactory = OAuth2Utils.JSON_FACTORY;
311311
long currentTimeMillis = Clock.SYSTEM.currentTimeMillis();
312312
String assertion =
313-
credentials.createAssertionForIdToken(
314-
jsonFactory, currentTimeMillis, null, "https://foo.com/bar");
313+
credentials.createAssertionForIdToken(currentTimeMillis, null, "https://foo.com/bar");
315314

316315
JsonWebSignature signature = JsonWebSignature.parse(jsonFactory, assertion);
317316
JsonWebToken.Payload payload = signature.getPayload();
@@ -329,8 +328,7 @@ public void createAssertionForIdToken_custom_lifetime() throws IOException {
329328
JsonFactory jsonFactory = OAuth2Utils.JSON_FACTORY;
330329
long currentTimeMillis = Clock.SYSTEM.currentTimeMillis();
331330
String assertion =
332-
credentials.createAssertionForIdToken(
333-
jsonFactory, currentTimeMillis, null, "https://foo.com/bar");
331+
credentials.createAssertionForIdToken(currentTimeMillis, null, "https://foo.com/bar");
334332

335333
JsonWebSignature signature = JsonWebSignature.parse(jsonFactory, assertion);
336334
JsonWebToken.Payload payload = signature.getPayload();
@@ -353,8 +351,7 @@ public void createAssertionForIdToken_incorrect() throws IOException {
353351
JsonFactory jsonFactory = OAuth2Utils.JSON_FACTORY;
354352
long currentTimeMillis = Clock.SYSTEM.currentTimeMillis();
355353
String assertion =
356-
credentials.createAssertionForIdToken(
357-
jsonFactory, currentTimeMillis, null, "https://foo.com/bar");
354+
credentials.createAssertionForIdToken(currentTimeMillis, null, "https://foo.com/bar");
358355

359356
JsonWebSignature signature = JsonWebSignature.parse(jsonFactory, assertion);
360357
JsonWebToken.Payload payload = signature.getPayload();
@@ -383,7 +380,7 @@ public void createdScoped_withAud_noUniverse_jwtWithScopesDisabled_accessToken()
383380

384381
GoogleCredentials scopedCredentials = credentials.createScoped(SCOPES);
385382
assertEquals(false, credentials.isExplicitUniverseDomain());
386-
assertEquals(GOOGLE_DEFAULT_UNIVERSE, credentials.getUniverseDomain());
383+
assertEquals(Credentials.GOOGLE_DEFAULT_UNIVERSE, credentials.getUniverseDomain());
387384
Map<String, List<String>> metadata = scopedCredentials.getRequestMetadata(CALL_URI);
388385
TestUtils.assertContainsBearerToken(metadata, ACCESS_TOKEN);
389386
}
@@ -518,7 +515,7 @@ public void fromJSON_getProjectId() throws IOException {
518515
ServiceAccountCredentials credentials =
519516
ServiceAccountCredentials.fromJson(json, new MockTokenServerTransportFactory());
520517
assertEquals(PROJECT_ID, credentials.getProjectId());
521-
assertEquals(GOOGLE_DEFAULT_UNIVERSE, credentials.getUniverseDomain());
518+
assertEquals(Credentials.GOOGLE_DEFAULT_UNIVERSE, credentials.getUniverseDomain());
522519
}
523520

524521
@Test
@@ -623,7 +620,7 @@ public void getRequestMetadata_customTokenServer_hasAccessToken() throws IOExcep
623620
@Test
624621
public void getUniverseDomain_defaultUniverse() throws IOException {
625622
ServiceAccountCredentials credentials = createDefaultBuilder().build();
626-
assertEquals(GOOGLE_DEFAULT_UNIVERSE, credentials.getUniverseDomain());
623+
assertEquals(Credentials.GOOGLE_DEFAULT_UNIVERSE, credentials.getUniverseDomain());
627624
}
628625

629626
@Test
@@ -850,7 +847,7 @@ public void refreshAccessToken_4xx_5xx_NonRetryableFails() throws IOException {
850847
}
851848

852849
@Test
853-
public void idTokenWithAudience_correct() throws IOException {
850+
public void idTokenWithAudience_oauthFlow_targetAudienceMatchesAudClaim() throws IOException {
854851
String accessToken1 = "1/MkSJoj1xsli0AccessToken_NKPY2";
855852
MockTokenServerTransportFactory transportFactory = new MockTokenServerTransportFactory();
856853
MockTokenServerTransport transport = transportFactory.transport;
@@ -869,13 +866,16 @@ public void idTokenWithAudience_correct() throws IOException {
869866
tokenCredential.refresh();
870867
assertEquals(DEFAULT_ID_TOKEN, tokenCredential.getAccessToken().getTokenValue());
871868
assertEquals(DEFAULT_ID_TOKEN, tokenCredential.getIdToken().getTokenValue());
869+
870+
// ID Token's aud claim is `https://foo.bar`
872871
assertEquals(
873872
targetAudience,
874-
(String) tokenCredential.getIdToken().getJsonWebSignature().getPayload().getAudience());
873+
tokenCredential.getIdToken().getJsonWebSignature().getPayload().getAudience());
875874
}
876875

877876
@Test
878-
public void idTokenWithAudience_incorrect() throws IOException {
877+
public void idTokenWithAudience_oauthFlow_targetAudienceDoesNotMatchAudClaim()
878+
throws IOException {
879879
String accessToken1 = "1/MkSJoj1xsli0AccessToken_NKPY2";
880880
MockTokenServerTransportFactory transportFactory = new MockTokenServerTransportFactory();
881881
MockTokenServerTransport transport = transportFactory.transport;
@@ -885,16 +885,78 @@ public void idTokenWithAudience_incorrect() throws IOException {
885885
transport.addServiceAccount(CLIENT_EMAIL, accessToken1);
886886
TestUtils.assertContainsBearerToken(credentials.getRequestMetadata(CALL_URI), accessToken1);
887887

888-
String targetAudience = "https://bar";
888+
String targetAudience = "differentAudience";
889+
IdTokenCredentials tokenCredential =
890+
IdTokenCredentials.newBuilder()
891+
.setIdTokenProvider(credentials)
892+
.setTargetAudience(targetAudience)
893+
.build();
894+
tokenCredential.refresh();
895+
896+
// ID Token's aud claim is `https://foo.bar`
897+
assertNotEquals(
898+
targetAudience,
899+
tokenCredential.getIdToken().getJsonWebSignature().getPayload().getAudience());
900+
}
901+
902+
@Test
903+
public void idTokenWithAudience_iamFlow_targetAudienceMatchesAudClaim() throws IOException {
904+
String nonGDU = "test.com";
905+
MockIAMCredentialsServiceTransportFactory transportFactory =
906+
new MockIAMCredentialsServiceTransportFactory(nonGDU);
907+
transportFactory.getTransport().setTargetPrincipal(CLIENT_EMAIL);
908+
transportFactory.getTransport().setIdToken(DEFAULT_ID_TOKEN);
909+
transportFactory.getTransport().addStatusCodeAndMessage(HttpStatusCodes.STATUS_CODE_OK, "");
910+
ServiceAccountCredentials credentials =
911+
createDefaultBuilder()
912+
.setScopes(SCOPES)
913+
.setHttpTransportFactory(transportFactory)
914+
.setUniverseDomain(nonGDU)
915+
.build();
916+
917+
String targetAudience = "https://foo.bar";
889918
IdTokenCredentials tokenCredential =
890919
IdTokenCredentials.newBuilder()
891920
.setIdTokenProvider(credentials)
892921
.setTargetAudience(targetAudience)
893922
.build();
894923
tokenCredential.refresh();
924+
assertEquals(DEFAULT_ID_TOKEN, tokenCredential.getAccessToken().getTokenValue());
925+
assertEquals(DEFAULT_ID_TOKEN, tokenCredential.getIdToken().getTokenValue());
926+
927+
// ID Token's aud claim is `https://foo.bar`
928+
assertEquals(
929+
targetAudience,
930+
tokenCredential.getIdToken().getJsonWebSignature().getPayload().getAudience());
931+
}
932+
933+
@Test
934+
public void idTokenWithAudience_iamFlow_targetAudienceDoesNotMatchAudClaim() throws IOException {
935+
String nonGDU = "test.com";
936+
MockIAMCredentialsServiceTransportFactory transportFactory =
937+
new MockIAMCredentialsServiceTransportFactory(nonGDU);
938+
transportFactory.getTransport().setTargetPrincipal(CLIENT_EMAIL);
939+
transportFactory.getTransport().setIdToken(DEFAULT_ID_TOKEN);
940+
transportFactory.getTransport().addStatusCodeAndMessage(HttpStatusCodes.STATUS_CODE_OK, "");
941+
ServiceAccountCredentials credentials =
942+
createDefaultBuilder()
943+
.setScopes(SCOPES)
944+
.setHttpTransportFactory(transportFactory)
945+
.setUniverseDomain(nonGDU)
946+
.build();
947+
948+
String targetAudience = "differentAudience";
949+
IdTokenCredentials tokenCredential =
950+
IdTokenCredentials.newBuilder()
951+
.setIdTokenProvider(credentials)
952+
.setTargetAudience(targetAudience)
953+
.build();
954+
tokenCredential.refresh();
955+
956+
// ID Token's aud claim is `https://foo.bar`
895957
assertNotEquals(
896958
targetAudience,
897-
(String) tokenCredential.getIdToken().getJsonWebSignature().getPayload().getAudience());
959+
tokenCredential.getIdToken().getJsonWebSignature().getPayload().getAudience());
898960
}
899961

900962
@Test

0 commit comments

Comments
 (0)
Please sign in to comment.