Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[4.x] Fix websocket reconnect race condition (#7815) #7817

Merged
merged 4 commits into from
Sep 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,8 @@ class OkHttpClientTestRule : TestRule {
println("After delay: " + connectionPool.connectionCount())
}

assertEquals(0, connectionPool.connectionCount())
connectionPool.evictAll()
assertEquals("Still ${connectionPool.connectionCount()} connections open", 0, connectionPool.connectionCount())
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,9 +170,9 @@ class RealWebSocket(
checkUpgradeSuccess(response, exchange)
streams = exchange!!.newWebSocketStreams()
} catch (e: IOException) {
exchange?.webSocketUpgradeFailed()
failWebSocket(e, response)
response.closeQuietly()
exchange?.webSocketUpgradeFailed()
return
}

Expand Down
81 changes: 74 additions & 7 deletions okhttp/src/test/java/okhttp3/internal/ws/WebSocketHttpTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,13 @@
import java.net.ProtocolException;
import java.net.SocketTimeoutException;
import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import okhttp3.ConnectionPool;
import okhttp3.OkHttpClient;
import okhttp3.OkHttpClientTestRule;
import okhttp3.Protocol;
Expand Down Expand Up @@ -315,11 +318,16 @@ private OkHttpClientTestRule configureClientTestRule() {
webServer.enqueue(new MockResponse()
.setResponseCode(101)
.setHeader("Upgrade", "websocket")
.setHeader("Sec-WebSocket-Accept", "ujmZX4KXZqjwy6vi1aQFH5p4Ygk="));
newWebSocket();
.setHeader("Sec-WebSocket-Accept", "ujmZX4KXZqjwy6vi1aQFH5p4Ygk=")
);
webServer.enqueue(new MockResponse().setSocketPolicy(SocketPolicy.DISCONNECT_AT_START));

RealWebSocket webSocket = newWebSocket();

clientListener.assertFailure(101, null, ProtocolException.class,
"Expected 'Connection' header value 'Upgrade' but was 'null'");

webSocket.cancel();
}

@Test public void wrongConnectionHeader() throws IOException {
Expand All @@ -328,21 +336,29 @@ private OkHttpClientTestRule configureClientTestRule() {
.setHeader("Upgrade", "websocket")
.setHeader("Connection", "Downgrade")
.setHeader("Sec-WebSocket-Accept", "ujmZX4KXZqjwy6vi1aQFH5p4Ygk="));
newWebSocket();
webServer.enqueue(new MockResponse().setSocketPolicy(SocketPolicy.DISCONNECT_AT_START));

RealWebSocket webSocket = newWebSocket();

clientListener.assertFailure(101, null, ProtocolException.class,
"Expected 'Connection' header value 'Upgrade' but was 'Downgrade'");

webSocket.cancel();
}

@Test public void missingUpgradeHeader() throws IOException {
webServer.enqueue(new MockResponse()
.setResponseCode(101)
.setHeader("Connection", "Upgrade")
.setHeader("Sec-WebSocket-Accept", "ujmZX4KXZqjwy6vi1aQFH5p4Ygk="));
newWebSocket();
webServer.enqueue(new MockResponse().setSocketPolicy(SocketPolicy.DISCONNECT_AT_START));

RealWebSocket webSocket = newWebSocket();

clientListener.assertFailure(101, null, ProtocolException.class,
"Expected 'Upgrade' header value 'websocket' but was 'null'");

webSocket.cancel();
}

@Test public void wrongUpgradeHeader() throws IOException {
Expand All @@ -351,21 +367,29 @@ private OkHttpClientTestRule configureClientTestRule() {
.setHeader("Connection", "Upgrade")
.setHeader("Upgrade", "Pepsi")
.setHeader("Sec-WebSocket-Accept", "ujmZX4KXZqjwy6vi1aQFH5p4Ygk="));
newWebSocket();
webServer.enqueue(new MockResponse().setSocketPolicy(SocketPolicy.DISCONNECT_AT_START));

RealWebSocket webSocket = newWebSocket();

clientListener.assertFailure(101, null, ProtocolException.class,
"Expected 'Upgrade' header value 'websocket' but was 'Pepsi'");

webSocket.cancel();
}

@Test public void missingMagicHeader() throws IOException {
webServer.enqueue(new MockResponse()
.setResponseCode(101)
.setHeader("Connection", "Upgrade")
.setHeader("Upgrade", "websocket"));
newWebSocket();
webServer.enqueue(new MockResponse().setSocketPolicy(SocketPolicy.DISCONNECT_AT_START));

RealWebSocket webSocket = newWebSocket();

clientListener.assertFailure(101, null, ProtocolException.class,
"Expected 'Sec-WebSocket-Accept' header value 'ujmZX4KXZqjwy6vi1aQFH5p4Ygk=' but was 'null'");

webSocket.cancel();
}

@Test public void wrongMagicHeader() throws IOException {
Expand All @@ -374,10 +398,14 @@ private OkHttpClientTestRule configureClientTestRule() {
.setHeader("Connection", "Upgrade")
.setHeader("Upgrade", "websocket")
.setHeader("Sec-WebSocket-Accept", "magic"));
newWebSocket();
webServer.enqueue(new MockResponse().setSocketPolicy(SocketPolicy.DISCONNECT_AT_START));

RealWebSocket webSocket = newWebSocket();

clientListener.assertFailure(101, null, ProtocolException.class,
"Expected 'Sec-WebSocket-Accept' header value 'ujmZX4KXZqjwy6vi1aQFH5p4Ygk=' but was 'magic'");

webSocket.cancel();
}

@Test public void clientIncludesForbiddenHeader() throws IOException {
Expand Down Expand Up @@ -800,6 +828,45 @@ private OkHttpClientTestRule configureClientTestRule() {
webSocket.close(1000, null);
}

/** https://github.com/square/okhttp/issues/7768 */
@Test public void reconnectingToNonWebSocket() throws InterruptedException {
for (int i = 0; i < 30; i++) {
webServer.enqueue(new MockResponse()
.setBodyDelay(100, TimeUnit.MILLISECONDS)
.setBody("Wrong endpoint")
.setResponseCode(401));
}

Request request = new Request.Builder()
.url(webServer.url("/"))
.build();

CountDownLatch attempts = new CountDownLatch(20);

List<WebSocket> webSockets = new ArrayList<>();

WebSocketListener reconnectOnFailure = new WebSocketListener() {
@Override
public void onFailure(WebSocket webSocket, Throwable t, Response response) {
if (attempts.getCount() > 0) {
clientListener.setNextEventDelegate(this);
webSockets.add(client.newWebSocket(request, clientListener));
attempts.countDown();
}
}
};

clientListener.setNextEventDelegate(reconnectOnFailure);

webSockets.add(client.newWebSocket(request, clientListener));

attempts.await();

for (WebSocket webSocket: webSockets) {
webSocket.cancel();
}
}

@Test public void compressedMessages() throws Exception {
successfulExtensions("permessage-deflate");
}
Expand Down