Skip to content

Commit d76e892

Browse files
authoredJul 17, 2024··
feat: Retry API calls that return a 5xx error. Fixes #2029. (#2041)
This updates the retry logic for API calls made by the connector: - Attempts to get an authentication token from the Google auth API are always retried 5 times. - SQL Admin API requests are retried up to 5 times if the http response status code is a 5xx. Otherwise they are fatal on the first error. This also implements the exponential backoff logic defined used in the go connector. See: GoogleCloudPlatform/cloud-sql-go-connector#781 Fixes #2029
1 parent cec32c3 commit d76e892

9 files changed

+334
-132
lines changed
 
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
/*
2+
* Copyright 2024 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.google.cloud.sql.core;
18+
19+
import com.google.api.client.http.HttpResponseException;
20+
import java.util.concurrent.Callable;
21+
22+
/**
23+
* Extends RetryingCallable with logic to only retry on HTTP errors with error codes in the 5xx
24+
* range.
25+
*
26+
* @param <T> the return value for Callable
27+
*/
28+
class ApiClientRetryingCallable<T> extends RetryingCallable<T> {
29+
30+
/**
31+
* Construct a new RetryLogic.
32+
*
33+
* @param callable the callable that should be retried
34+
*/
35+
public ApiClientRetryingCallable(Callable<T> callable) {
36+
super(callable);
37+
}
38+
39+
/**
40+
* Returns false indicating that there should be another attempt if the exception is an HTTP
41+
* response with an error code in the 5xx range.
42+
*
43+
* @param e the exception
44+
* @return false if this is a http response with a 5xx status code, otherwise true.
45+
*/
46+
@Override
47+
protected boolean isFatalException(Exception e) {
48+
// Only retry if the error is an HTTP response with a 5xx error code.
49+
if (e instanceof HttpResponseException) {
50+
HttpResponseException re = (HttpResponseException) e;
51+
return re.getStatusCode() < 500;
52+
}
53+
// Otherwise this is a fatal exception, no more tries.
54+
return true;
55+
}
56+
}

‎core/src/main/java/com/google/cloud/sql/core/DefaultAccessTokenSupplier.java

+2-21
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
import com.google.cloud.sql.AuthType;
2222
import com.google.cloud.sql.CredentialFactory;
2323
import java.io.IOException;
24-
import java.time.Duration;
2524
import java.time.Instant;
2625
import java.time.ZoneId;
2726
import java.time.format.DateTimeFormatter;
@@ -41,8 +40,6 @@ class DefaultAccessTokenSupplier implements AccessTokenSupplier {
4140
private static final String SQL_LOGIN_SCOPE = "https://www.googleapis.com/auth/sqlservice.login";
4241

4342
private final CredentialFactory credentialFactory;
44-
private final int retryCount;
45-
private final Duration retryDuration;
4643

4744
static AccessTokenSupplier newInstance(AuthType authType, CredentialFactory tokenSourceFactory) {
4845
if (authType == AuthType.IAM) {
@@ -52,27 +49,13 @@ static AccessTokenSupplier newInstance(AuthType authType, CredentialFactory toke
5249
}
5350
}
5451

55-
/**
56-
* Creates an instance with default retry settings.
57-
*
58-
* @param tokenSource the token source that produces auth tokens.
59-
*/
60-
DefaultAccessTokenSupplier(CredentialFactory tokenSource) {
61-
this(tokenSource, 3, Duration.ofSeconds(3));
62-
}
63-
6452
/**
6553
* Creates an instance with configurable retry settings.
6654
*
6755
* @param tokenSource the token source
68-
* @param retryCount the number of attempts to refresh.
69-
* @param retryDuration the duration to wait between attempts.
7056
*/
71-
DefaultAccessTokenSupplier(
72-
CredentialFactory tokenSource, int retryCount, Duration retryDuration) {
57+
DefaultAccessTokenSupplier(CredentialFactory tokenSource) {
7358
this.credentialFactory = tokenSource;
74-
this.retryCount = retryCount;
75-
this.retryDuration = retryDuration;
7659
}
7760

7861
/**
@@ -141,9 +124,7 @@ public Optional<AccessToken> get() throws IOException {
141124
}
142125

143126
return Optional.of(downscoped.getAccessToken());
144-
},
145-
this.retryCount,
146-
this.retryDuration);
127+
});
147128

148129
try {
149130
return retries.call();

‎core/src/main/java/com/google/cloud/sql/core/DefaultConnectionInfoRepository.java

+19-12
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ public ConnectionInfo getConnectionInfoSync(
112112
InstanceMetadata metadata = fetchMetadata(instanceName, authType);
113113
Certificate ephemeralCertificate =
114114
fetchEphemeralCertificate(keyPair, instanceName, token, authType);
115+
115116
SslData sslContext =
116117
createSslData(keyPair, metadata, ephemeralCertificate, instanceName, authType);
117118

@@ -228,10 +229,13 @@ String getQuotaProject(String connectionName) {
228229
private InstanceMetadata fetchMetadata(CloudSqlInstanceName instanceName, AuthType authType) {
229230
try {
230231
ConnectSettings instanceMetadata =
231-
apiClient
232-
.connect()
233-
.get(instanceName.getProjectId(), instanceName.getInstanceId())
234-
.execute();
232+
new ApiClientRetryingCallable<>(
233+
() ->
234+
apiClient
235+
.connect()
236+
.get(instanceName.getProjectId(), instanceName.getInstanceId())
237+
.execute())
238+
.call();
235239

236240
// Validate the instance will support the authenticated connection.
237241
if (!instanceMetadata.getRegion().equals(instanceName.getRegionId())) {
@@ -293,7 +297,7 @@ private InstanceMetadata fetchMetadata(CloudSqlInstanceName instanceName, AuthTy
293297
instanceName.getConnectionName()),
294298
ex);
295299
}
296-
} catch (IOException ex) {
300+
} catch (Exception ex) {
297301
throw addExceptionContext(
298302
ex,
299303
String.format(
@@ -326,12 +330,15 @@ private Certificate fetchEphemeralCertificate(
326330
GenerateEphemeralCertResponse response;
327331
try {
328332
response =
329-
apiClient
330-
.connect()
331-
.generateEphemeralCert(
332-
instanceName.getProjectId(), instanceName.getInstanceId(), request)
333-
.execute();
334-
} catch (IOException ex) {
333+
new ApiClientRetryingCallable<>(
334+
() ->
335+
apiClient
336+
.connect()
337+
.generateEphemeralCert(
338+
instanceName.getProjectId(), instanceName.getInstanceId(), request)
339+
.execute())
340+
.call();
341+
} catch (Exception ex) {
335342
throw addExceptionContext(
336343
ex,
337344
String.format(
@@ -425,7 +432,7 @@ private SslData createSslData(
425432
* provided to the user
426433
*/
427434
private RuntimeException addExceptionContext(
428-
IOException ex, String fallbackDesc, CloudSqlInstanceName instanceName) {
435+
Exception ex, String fallbackDesc, CloudSqlInstanceName instanceName) {
429436
String reason = fallbackDesc;
430437
int statusCode = 0;
431438

‎core/src/main/java/com/google/cloud/sql/core/RetryingCallable.java

+34-26
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
package com.google.cloud.sql.core;
1818

19-
import java.time.Duration;
2019
import java.util.concurrent.Callable;
2120
import java.util.concurrent.ThreadLocalRandom;
2221

@@ -25,59 +24,57 @@
2524
* The sleep duration is chosen randomly in the range [sleepDuration, sleepDuration * 2] to avoid
2625
* causing a thundering herd of requests on failure.
2726
*
27+
* <p>exponentialBackoff calculates a duration based on the attempt i.
28+
*
29+
* <p>The formula is: base * multi^(attempt + 1 + random)
30+
*
31+
* <p>With base = 200ms and multi = 1.618, and random = [0.0, 1.0), the backoff values would fall
32+
* between the following low and high ends:
33+
*
34+
* <p>Attempt Low (ms) High (ms)
35+
*
36+
* <p>0 324 524 1 524 847 2 847 1371 3 1371 2218 4 2218 3588
37+
*
38+
* <p>The theoretical worst case scenario would have a client wait 8.5s in total for an API request
39+
* to complete (with the first four attempts failing, and the fifth succeeding).
40+
*
41+
* <p>This backoff strategy matches the behavior of the Cloud SQL Proxy v1.
42+
*
2843
* @param <T> the result type of the Callable.
2944
*/
3045
class RetryingCallable<T> implements Callable<T> {
46+
private static final int RETRY_COUNT = 5;
3147

3248
/** The callable that should be retried. */
3349
private final Callable<T> callable;
34-
/** The number of times to attempt to retry. */
35-
private final int retryCount;
36-
/** The duration to sleep after a failed retry attempt. */
37-
private final Duration sleepDuration;
3850

3951
/**
4052
* Construct a new RetryLogic.
4153
*
4254
* @param callable the callable that should be retried
43-
* @param retryCount the number of times to retry
44-
* @param sleepDuration the duration wait after a failed attempt.
4555
*/
46-
public RetryingCallable(Callable<T> callable, int retryCount, Duration sleepDuration) {
47-
if (retryCount <= 0) {
48-
throw new IllegalArgumentException("retryCount must be > 0");
49-
}
50-
if (sleepDuration.isNegative() || sleepDuration.isZero()) {
51-
throw new IllegalArgumentException("sleepDuration must be positive");
52-
}
56+
public RetryingCallable(Callable<T> callable) {
5357
if (callable == null) {
5458
throw new IllegalArgumentException("call must not be null");
5559
}
5660
this.callable = callable;
57-
this.retryCount = retryCount;
58-
this.sleepDuration = sleepDuration;
5961
}
6062

6163
@Override
6264
public T call() throws Exception {
6365

64-
for (int retriesLeft = retryCount - 1; retriesLeft >= 0; retriesLeft--) {
66+
for (int attempt = 0; attempt < RETRY_COUNT; attempt++) {
6567
// Attempt to call the Callable.
6668
try {
6769
return callable.call();
6870
} catch (Exception e) {
69-
// Callable threw an exception.
70-
71-
// If this is the last iteration, then
72-
// throw the exception
73-
if (retriesLeft == 0) {
71+
// If this is the last retry attempt, or if the exception is fatal
72+
// then exit immediately.
73+
if (attempt == (RETRY_COUNT - 1) || isFatalException(e)) {
7474
throw e;
7575
}
76-
7776
// Else, sleep a random amount of time, then retry
78-
long sleep =
79-
ThreadLocalRandom.current()
80-
.nextLong(sleepDuration.toMillis(), sleepDuration.toMillis() * 2);
77+
long sleep = exponentialBackoffMs(attempt);
8178
try {
8279
Thread.sleep(sleep);
8380
} catch (InterruptedException ie) {
@@ -90,4 +87,15 @@ public T call() throws Exception {
9087
// as long as the preconditions in the constructor are properly met.
9188
throw new RuntimeException("call was never called.");
9289
}
90+
91+
protected boolean isFatalException(Exception e) {
92+
return false;
93+
}
94+
95+
private long exponentialBackoffMs(int attempt) {
96+
long baseMs = 200;
97+
double multi = 1.618;
98+
double exp = attempt + 1.0 + ThreadLocalRandom.current().nextDouble();
99+
return (long) (baseMs * Math.pow(multi, exp));
100+
}
93101
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
/*
2+
* Copyright 2024 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.google.cloud.sql.core;
18+
19+
import static com.google.common.truth.Truth.assertThat;
20+
21+
import com.google.api.client.http.HttpHeaders;
22+
import com.google.api.client.http.HttpResponseException;
23+
import java.util.concurrent.atomic.AtomicInteger;
24+
import org.junit.Assert;
25+
import org.junit.Test;
26+
27+
public class ApiClientRetryingCallableTest {
28+
@Test
29+
public void testApiClientRetriesOn500ErrorAndSucceeds() throws Exception {
30+
AtomicInteger counter = new AtomicInteger(0);
31+
ApiClientRetryingCallable<Integer> c =
32+
new ApiClientRetryingCallable<>(
33+
() -> {
34+
int attempt = counter.incrementAndGet();
35+
if (attempt < 3) {
36+
throw new HttpResponseException.Builder(
37+
503, "service unavailable", new HttpHeaders())
38+
.build();
39+
}
40+
return attempt;
41+
});
42+
43+
Integer v = c.call();
44+
assertThat(counter.get()).isEqualTo(3);
45+
assertThat(v).isEqualTo(3);
46+
}
47+
48+
@Test
49+
public void testApiClientRetriesOn500ErrorAndFailsAfter5Attempts() throws Exception {
50+
AtomicInteger counter = new AtomicInteger(0);
51+
ApiClientRetryingCallable<Integer> c =
52+
new ApiClientRetryingCallable<>(
53+
() -> {
54+
counter.incrementAndGet();
55+
throw new HttpResponseException.Builder(503, "service unavailable", new HttpHeaders())
56+
.build();
57+
});
58+
59+
try {
60+
c.call();
61+
Assert.fail("got no exception, wants an exception to be thrown");
62+
} catch (Exception e) {
63+
// Expected to throw an exception
64+
}
65+
assertThat(counter.get()).isEqualTo(5);
66+
}
67+
68+
@Test
69+
public void testRetryStopsAfterFatalException() throws Exception {
70+
final AtomicInteger counter = new AtomicInteger();
71+
RetryingCallable<Integer> r =
72+
new RetryingCallable<Integer>(
73+
() -> {
74+
counter.incrementAndGet();
75+
throw new Exception("nope");
76+
}) {
77+
@Override
78+
protected boolean isFatalException(Exception e) {
79+
return true;
80+
}
81+
};
82+
83+
try {
84+
r.call();
85+
Assert.fail("got no exception, wants an exception to be thrown");
86+
} catch (Exception e) {
87+
// Expected to throw an exception
88+
}
89+
assertThat(counter.get()).isEqualTo(1);
90+
}
91+
}

‎core/src/test/java/com/google/cloud/sql/core/CloudSqlCoreTestingBase.java

+42-10
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
import java.time.Duration;
4343
import java.util.Base64;
4444
import java.util.Collections;
45+
import java.util.concurrent.ConcurrentHashMap;
4546
import org.junit.Before;
4647

4748
public class CloudSqlCoreTestingBase {
@@ -65,7 +66,7 @@ public class CloudSqlCoreTestingBase {
6566
public CloudSqlCoreTestingBase() {}
6667

6768
// Creates a fake "accessNotConfigured" exception that can be used for testing.
68-
static HttpTransport fakeNotConfiguredException() {
69+
static MockHttpTransport fakeNotConfiguredException() {
6970
return fakeGoogleJsonResponseException(
7071
"accessNotConfigured",
7172
"Cloud SQL Admin API has not been used in project 12345 before or it is disabled. Enable"
@@ -77,27 +78,27 @@ static HttpTransport fakeNotConfiguredException() {
7778
}
7879

7980
// Creates a fake "notAuthorized" exception that can be used for testing.
80-
static HttpTransport fakeNotAuthorizedException() {
81+
static MockHttpTransport fakeNotAuthorizedException() {
8182
return fakeGoogleJsonResponseException(
8283
"notAuthorized", ERROR_MESSAGE_NOT_AUTHORIZED, HttpStatusCodes.STATUS_CODE_UNAUTHORIZED);
8384
}
8485

8586
// Creates a fake "serverError" exception that can be used for testing.
86-
static HttpTransport fakeBadGatewayException() {
87+
static MockHttpTransport fakeBadGatewayException() {
8788
return fakeGoogleJsonResponseException(
8889
"serverError", ERROR_MESSAGE_BAD_GATEWAY, HttpStatusCodes.STATUS_CODE_BAD_GATEWAY);
8990
}
9091

9192
// Builds a fake GoogleJsonResponseException for testing API error handling.
92-
private static HttpTransport fakeGoogleJsonResponseException(
93+
private static MockHttpTransport fakeGoogleJsonResponseException(
9394
String reason, String message, int statusCode) {
9495
ErrorInfo errorInfo = new ErrorInfo();
9596
errorInfo.setReason(reason);
9697
errorInfo.setMessage(message);
9798
return fakeGoogleJsonResponseExceptionTransport(errorInfo, message, statusCode);
9899
}
99100

100-
private static HttpTransport fakeGoogleJsonResponseExceptionTransport(
101+
private static MockHttpTransport fakeGoogleJsonResponseExceptionTransport(
101102
ErrorInfo errorInfo, String message, int statusCode) {
102103
final JsonFactory jsonFactory = new GsonFactory();
103104
return new MockHttpTransport() {
@@ -117,7 +118,7 @@ public LowLevelHttpRequest buildRequest(String method, String url) throws IOExce
117118
new MockLowLevelHttpResponse()
118119
.setContent(errorResponse.toPrettyString())
119120
.setContentType(Json.MEDIA_TYPE)
120-
.setStatusCode(HttpStatusCodes.STATUS_CODE_FORBIDDEN));
121+
.setStatusCode(statusCode));
121122
}
122123
};
123124
}
@@ -131,19 +132,20 @@ public void setup() throws GeneralSecurityException {
131132
clientKeyPair = Futures.immediateFuture(TestKeys.getClientKeyPair());
132133
}
133134

134-
HttpTransport fakeSuccessHttpTransport(Duration certDuration) {
135+
MockHttpTransport fakeSuccessHttpTransport(Duration certDuration) {
135136
return fakeSuccessHttpTransport(TestKeys.getServerCertPem(), certDuration, null);
136137
}
137138

138-
HttpTransport fakeSuccessHttpTransport(Duration certDuration, String baseUrl) {
139+
MockHttpTransport fakeSuccessHttpTransport(Duration certDuration, String baseUrl) {
139140
return fakeSuccessHttpTransport(TestKeys.getServerCertPem(), certDuration, baseUrl);
140141
}
141142

142-
HttpTransport fakeSuccessHttpTransport(String serverCert, Duration certDuration) {
143+
MockHttpTransport fakeSuccessHttpTransport(String serverCert, Duration certDuration) {
143144
return fakeSuccessHttpTransport(serverCert, certDuration, null);
144145
}
145146

146-
HttpTransport fakeSuccessHttpTransport(String serverCert, Duration certDuration, String baseUrl) {
147+
MockHttpTransport fakeSuccessHttpTransport(
148+
String serverCert, Duration certDuration, String baseUrl) {
147149
final JsonFactory jsonFactory = new GsonFactory();
148150
return new MockHttpTransport() {
149151
@Override
@@ -187,4 +189,34 @@ public LowLevelHttpResponse execute() throws IOException {
187189
}
188190
};
189191
}
192+
193+
HttpTransport fakeIntermittentErrorHttpTransport(
194+
MockHttpTransport successTransport, MockHttpTransport errorTransport) {
195+
// Odd request counts get errors, Even request count gets success.
196+
ConcurrentHashMap<String, Integer> counterByUrl = new ConcurrentHashMap<>();
197+
198+
return new MockHttpTransport() {
199+
@Override
200+
public LowLevelHttpRequest buildRequest(String method, String url) {
201+
return new MockLowLevelHttpRequest() {
202+
@Override
203+
public LowLevelHttpResponse execute() throws IOException {
204+
int count =
205+
counterByUrl.compute(
206+
url,
207+
(url, val) -> {
208+
if (val == null) {
209+
return 1;
210+
}
211+
return val + 1;
212+
});
213+
if (count % 2 == 0) {
214+
return successTransport.buildRequest(method, url).execute();
215+
}
216+
return errorTransport.buildRequest(method, url).execute();
217+
}
218+
};
219+
}
220+
};
221+
}
190222
}

‎core/src/test/java/com/google/cloud/sql/core/ConnectorTest.java

+44-8
Original file line numberDiff line numberDiff line change
@@ -337,15 +337,51 @@ public void create_throwsException_badGateway() throws IOException {
337337
TEST_MAX_REFRESH_MS,
338338
DEFAULT_SERVER_PROXY_PORT);
339339

340-
TerminalException ex =
341-
assertThrows(TerminalException.class, () -> c.connect(config, TEST_MAX_REFRESH_MS));
340+
// If the gateway is down, then this is a temporary error, not a fatal error.
341+
RuntimeException ex =
342+
assertThrows(RuntimeException.class, () -> c.connect(config, TEST_MAX_REFRESH_MS));
342343

343-
assertThat(ex)
344-
.hasMessageThat()
345-
.contains(
346-
String.format(
347-
"[%s] The Google Cloud SQL Admin API failed for the project",
348-
"myProject:myRegion:NotMyInstance"));
344+
// The Connector.connect() method will timeout, and will include details about the instance
345+
// data in the test.
346+
assertThat(ex).hasMessageThat().contains("Unable to get valid instance data within");
347+
348+
assertThat(ex).hasMessageThat().contains("502");
349+
350+
assertThat(ex).hasMessageThat().contains("myProject:myRegion:NotMyInstance");
351+
}
352+
353+
@Test
354+
public void create_successfulPublicConnection_withIntermittentBadGatewayErrors()
355+
throws IOException, InterruptedException {
356+
ConnectionInfoRepositoryFactory factory =
357+
new StubConnectionInfoRepositoryFactory(
358+
fakeIntermittentErrorHttpTransport(
359+
fakeSuccessHttpTransport(Duration.ofSeconds(0)), fakeBadGatewayException()));
360+
361+
FakeSslServer sslServer = new FakeSslServer();
362+
363+
ConnectionConfig config =
364+
new ConnectionConfig.Builder()
365+
.withCloudSqlInstance("myProject:myRegion:myInstance")
366+
.withIpTypes("PRIMARY")
367+
.build();
368+
369+
int port = sslServer.start(PUBLIC_IP);
370+
371+
Connector c =
372+
new Connector(
373+
config.getConnectorConfig(),
374+
factory,
375+
stubCredentialFactoryProvider.getInstanceCredentialFactory(config.getConnectorConfig()),
376+
defaultExecutor,
377+
clientKeyPair,
378+
10,
379+
TEST_MAX_REFRESH_MS,
380+
port);
381+
382+
Socket socket = c.connect(config, TEST_MAX_REFRESH_MS);
383+
384+
assertThat(readLine(socket)).isEqualTo(SERVER_MESSAGE);
349385
}
350386

351387
@Test

‎core/src/test/java/com/google/cloud/sql/core/DefaultAccessTokenSupplierTest.java

+14-27
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@
3737
import com.google.auth.oauth2.OAuth2Credentials;
3838
import com.google.cloud.sql.CredentialFactory;
3939
import java.io.IOException;
40-
import java.time.Duration;
4140
import java.time.Instant;
4241
import java.time.temporal.ChronoUnit;
4342
import java.util.Collections;
@@ -75,8 +74,7 @@ public AccessToken refreshAccessToken() throws IOException {
7574

7675
@Test
7776
public void testEmptyTokenOnEmptyCredentials() throws IOException {
78-
DefaultAccessTokenSupplier supplier =
79-
new DefaultAccessTokenSupplier(null, 1, Duration.ofMillis(10));
77+
DefaultAccessTokenSupplier supplier = new DefaultAccessTokenSupplier(null);
8078
assertThat(supplier.get()).isEqualTo(Optional.empty());
8179
}
8280

@@ -100,8 +98,7 @@ public AccessToken refreshAccessToken() throws IOException {
10098
};
10199

102100
DefaultAccessTokenSupplier supplier =
103-
new DefaultAccessTokenSupplier(
104-
new GoogleCredentialsFactory(googleCredentials), 1, Duration.ofMillis(10));
101+
new DefaultAccessTokenSupplier(new GoogleCredentialsFactory(googleCredentials));
105102
Optional<AccessToken> token = supplier.get();
106103

107104
assertThat(token.isPresent()).isTrue();
@@ -127,8 +124,7 @@ public HttpRequestInitializer create() {
127124
}
128125
};
129126

130-
DefaultAccessTokenSupplier supplier =
131-
new DefaultAccessTokenSupplier(badFactory, 1, Duration.ofMillis(10));
127+
DefaultAccessTokenSupplier supplier = new DefaultAccessTokenSupplier(badFactory);
132128
RuntimeException ex = assertThrows(RuntimeException.class, supplier::get);
133129
assertThat(ex).hasMessageThat().contains("Unsupported credentials of type");
134130
}
@@ -153,11 +149,10 @@ public AccessToken refreshAccessToken() throws IOException {
153149
};
154150

155151
DefaultAccessTokenSupplier supplier =
156-
new DefaultAccessTokenSupplier(
157-
new GoogleCredentialsFactory(expiredGoogleCredentials), 1, Duration.ofMillis(10));
152+
new DefaultAccessTokenSupplier(new GoogleCredentialsFactory(expiredGoogleCredentials));
158153
IllegalStateException ex = assertThrows(IllegalStateException.class, supplier::get);
159154
assertThat(ex).hasMessageThat().contains("Error refreshing credentials");
160-
assertThat(refreshCounter.get()).isEqualTo(1);
155+
assertThat(refreshCounter.get()).isEqualTo(5);
161156
}
162157

163158
@Test
@@ -180,11 +175,10 @@ public AccessToken refreshAccessToken() throws IOException {
180175
};
181176

182177
DefaultAccessTokenSupplier supplier =
183-
new DefaultAccessTokenSupplier(
184-
new GoogleCredentialsFactory(refreshGetsExpiredToken), 1, Duration.ofMillis(10));
178+
new DefaultAccessTokenSupplier(new GoogleCredentialsFactory(refreshGetsExpiredToken));
185179
IllegalStateException ex = assertThrows(IllegalStateException.class, supplier::get);
186180
assertThat(ex).hasMessageThat().contains("expiration time is in the past");
187-
assertThat(refreshCounter.get()).isEqualTo(1);
181+
assertThat(refreshCounter.get()).isEqualTo(5);
188182
}
189183

190184
@Test
@@ -206,8 +200,7 @@ public AccessToken refreshAccessToken() throws IOException {
206200
};
207201

208202
DefaultAccessTokenSupplier supplier =
209-
new DefaultAccessTokenSupplier(
210-
new GoogleCredentialsFactory(refreshableCredentials), 1, Duration.ofMillis(10));
203+
new DefaultAccessTokenSupplier(new GoogleCredentialsFactory(refreshableCredentials));
211204
Optional<AccessToken> token = supplier.get();
212205

213206
assertThat(token.isPresent()).isTrue();
@@ -239,8 +232,7 @@ public AccessToken refreshAccessToken() throws IOException {
239232
};
240233

241234
DefaultAccessTokenSupplier supplier =
242-
new DefaultAccessTokenSupplier(
243-
new GoogleCredentialsFactory(refreshableCredentials), 3, Duration.ofMillis(10));
235+
new DefaultAccessTokenSupplier(new GoogleCredentialsFactory(refreshableCredentials));
244236
Optional<AccessToken> token = supplier.get();
245237

246238
assertThat(token.isPresent()).isTrue();
@@ -271,8 +263,7 @@ public GoogleCredentials createScoped(String... scopes) {
271263
public void throwsErrorForWrongCredentialType() {
272264
OAuth2Credentials creds = OAuth2Credentials.create(new AccessToken("abc", null));
273265
DefaultAccessTokenSupplier supplier =
274-
new DefaultAccessTokenSupplier(
275-
new Oauth2BadCredentialFactory(creds), 1, Duration.ofMillis(10));
266+
new DefaultAccessTokenSupplier(new Oauth2BadCredentialFactory(creds));
276267
RuntimeException ex = assertThrows(RuntimeException.class, supplier::get);
277268

278269
assertThat(ex)
@@ -291,8 +282,7 @@ public GoogleCredentials createScoped(String... scopes) {
291282
}
292283
};
293284
DefaultAccessTokenSupplier supplier =
294-
new DefaultAccessTokenSupplier(
295-
new GoogleCredentialsFactory(creds), 1, Duration.ofMillis(10));
285+
new DefaultAccessTokenSupplier(new GoogleCredentialsFactory(creds));
296286
RuntimeException ex = assertThrows(RuntimeException.class, supplier::get);
297287

298288
assertThat(ex).hasMessageThat().contains("Access Token has length of zero");
@@ -317,8 +307,7 @@ public AccessToken refreshAccessToken() throws IOException {
317307
};
318308

319309
DefaultAccessTokenSupplier supplier =
320-
new DefaultAccessTokenSupplier(
321-
new GoogleCredentialsFactory(refreshableCredentials), 1, Duration.ofMillis(10));
310+
new DefaultAccessTokenSupplier(new GoogleCredentialsFactory(refreshableCredentials));
322311
RuntimeException ex = assertThrows(RuntimeException.class, supplier::get);
323312

324313
assertThat(ex).hasMessageThat().contains("Access Token expiration time is in the past");
@@ -368,8 +357,7 @@ public LowLevelHttpResponse execute() throws IOException {
368357
credential.setExpirationTimeMilliseconds(future.toEpochMilli());
369358

370359
DefaultAccessTokenSupplier supplier =
371-
new DefaultAccessTokenSupplier(
372-
new Oauth2CredentialFactory(credential), 1, Duration.ofMillis(10));
360+
new DefaultAccessTokenSupplier(new Oauth2CredentialFactory(credential));
373361
Optional<AccessToken> token = supplier.get();
374362

375363
assertThat(token.isPresent()).isTrue();
@@ -438,8 +426,7 @@ public void intercept(HttpRequest request) throws IOException {
438426
credential.setExpirationTimeMilliseconds(past.toEpochMilli());
439427

440428
DefaultAccessTokenSupplier supplier =
441-
new DefaultAccessTokenSupplier(
442-
new Oauth2CredentialFactory(credential), 1, Duration.ofMillis(10));
429+
new DefaultAccessTokenSupplier(new Oauth2CredentialFactory(credential));
443430
Optional<AccessToken> token = supplier.get();
444431

445432
assertThat(token.isPresent()).isTrue();

‎core/src/test/java/com/google/cloud/sql/core/RetryingCallableTest.java

+32-28
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818

1919
import static com.google.common.truth.Truth.assertThat;
2020

21-
import java.time.Duration;
2221
import java.util.concurrent.atomic.AtomicInteger;
2322
import org.junit.Assert;
2423
import org.junit.Test;
@@ -27,30 +26,12 @@ public class RetryingCallableTest {
2726
@Test
2827
public void testConstructorIllegalArguments() throws Exception {
2928
// Callable must not be null
30-
Assert.assertThrows(
31-
IllegalArgumentException.class,
32-
() -> new RetryingCallable<>(null, 1, Duration.ofMillis(1)));
33-
34-
// Must have a positive retryCount
35-
Assert.assertThrows(
36-
IllegalArgumentException.class,
37-
() -> new RetryingCallable<>(() -> null, 0, Duration.ofMillis(1)));
38-
Assert.assertThrows(
39-
IllegalArgumentException.class,
40-
() -> new RetryingCallable<>(() -> null, -1, Duration.ofMillis(1)));
41-
42-
// Must have a positive duration
43-
Assert.assertThrows(
44-
IllegalArgumentException.class,
45-
() -> new RetryingCallable<>(() -> null, 2, Duration.ofMillis(-5)));
46-
Assert.assertThrows(
47-
IllegalArgumentException.class,
48-
() -> new RetryingCallable<>(() -> null, 2, Duration.ofMillis(0)));
29+
Assert.assertThrows(IllegalArgumentException.class, () -> new RetryingCallable<>(null));
4930
}
5031

5132
@Test
5233
public void testNoRetryRequired() throws Exception {
53-
RetryingCallable<Integer> r = new RetryingCallable<>(() -> 1, 5, Duration.ofMillis(100));
34+
RetryingCallable<Integer> r = new RetryingCallable<>(() -> 1);
5435
int v = r.call();
5536
assertThat(v).isEqualTo(1);
5637
}
@@ -63,9 +44,7 @@ public void testAlwaysFails() {
6344
() -> {
6445
counter.incrementAndGet();
6546
throw new Exception("nope");
66-
},
67-
3,
68-
Duration.ofMillis(100));
47+
});
6948

7049
try {
7150
r.call();
@@ -74,7 +53,7 @@ public void testAlwaysFails() {
7453
// Expected to throw an exception
7554
}
7655

77-
assertThat(counter.get()).isEqualTo(3);
56+
assertThat(counter.get()).isEqualTo(5);
7857
}
7958

8059
@Test
@@ -88,12 +67,37 @@ public void testRetrySucceedsAfterFailures() throws Exception {
8867
throw new Exception("nope");
8968
}
9069
return i;
91-
},
92-
5,
93-
Duration.ofMillis(100));
70+
});
9471

9572
int v = r.call();
9673
assertThat(counter.get()).isEqualTo(3);
9774
assertThat(v).isEqualTo(3);
9875
}
76+
77+
@Test
78+
public void testRetryStopsAfterFatalException() throws Exception {
79+
final AtomicInteger counter = new AtomicInteger();
80+
RetryingCallable<Integer> r =
81+
new RetryingCallable<Integer>(
82+
() -> {
83+
int i = counter.incrementAndGet();
84+
if (i < 3) {
85+
throw new Exception("nope");
86+
}
87+
return i;
88+
}) {
89+
@Override
90+
protected boolean isFatalException(Exception e) {
91+
return true;
92+
}
93+
};
94+
95+
try {
96+
r.call();
97+
Assert.fail("got no exception, wants an exception to be thrown");
98+
} catch (Exception e) {
99+
// Expected to throw an exception
100+
}
101+
assertThat(counter.get()).isEqualTo(1);
102+
}
99103
}

0 commit comments

Comments
 (0)
Please sign in to comment.