Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix websocket reconnect race condition #7815

Merged
merged 18 commits into from
May 13, 2023
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,9 @@ class OkHttpClientTestRule : BeforeEachCallback, AfterEachCallback {
println("After delay: " + connectionPool.connectionCount())
}

assertEquals(0, connectionPool.connectionCount())
assertEquals(0, connectionPool.connectionCount()) {
"Still ${connectionPool.connectionCount()} connections open"
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,9 +170,9 @@ class RealWebSocket(
checkUpgradeSuccess(response, exchange)
streams = exchange!!.newWebSocketStreams()
} catch (e: IOException) {
exchange?.webSocketUpgradeFailed()
failWebSocket(e, response)
response.closeQuietly()
exchange?.webSocketUpgradeFailed()
return
}

Expand Down
98 changes: 92 additions & 6 deletions okhttp/src/jvmTest/java/okhttp3/internal/ws/WebSocketHttpTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import java.net.ProtocolException;
import java.net.SocketTimeoutException;
import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
Expand All @@ -33,6 +35,7 @@
import mockwebserver3.SocketPolicy;
import mockwebserver3.SocketPolicy.KeepOpen;
import mockwebserver3.SocketPolicy.NoResponse;
import okhttp3.ConnectionPool;
import okhttp3.OkHttpClient;
import okhttp3.OkHttpClientTestRule;
import okhttp3.Protocol;
Expand Down Expand Up @@ -351,10 +354,16 @@ private OkHttpClientTestRule configureClientTestRule() {
.setHeader("Upgrade", "websocket")
.setHeader("Sec-WebSocket-Accept", "ujmZX4KXZqjwy6vi1aQFH5p4Ygk=")
.build());
newWebSocket();
webServer.enqueue(new MockResponse.Builder()
.socketPolicy(SocketPolicy.DisconnectAtStart.INSTANCE)
.build());

RealWebSocket webSocket = newWebSocket();

clientListener.assertFailure(101, null, ProtocolException.class,
"Expected 'Connection' header value 'Upgrade' but was 'null'");

webSocket.cancel();
}

@Test public void wrongConnectionHeader() throws IOException {
Expand All @@ -364,10 +373,16 @@ private OkHttpClientTestRule configureClientTestRule() {
.setHeader("Connection", "Downgrade")
.setHeader("Sec-WebSocket-Accept", "ujmZX4KXZqjwy6vi1aQFH5p4Ygk=")
.build());
newWebSocket();
webServer.enqueue(new MockResponse.Builder()
.socketPolicy(SocketPolicy.DisconnectAtStart.INSTANCE)
.build());

RealWebSocket webSocket = newWebSocket();

clientListener.assertFailure(101, null, ProtocolException.class,
"Expected 'Connection' header value 'Upgrade' but was 'Downgrade'");

webSocket.cancel();
}

@Test public void missingUpgradeHeader() throws IOException {
Expand All @@ -376,10 +391,16 @@ private OkHttpClientTestRule configureClientTestRule() {
.setHeader("Connection", "Upgrade")
.setHeader("Sec-WebSocket-Accept", "ujmZX4KXZqjwy6vi1aQFH5p4Ygk=")
.build());
newWebSocket();
webServer.enqueue(new MockResponse.Builder()
.socketPolicy(SocketPolicy.DisconnectAtStart.INSTANCE)
.build());

RealWebSocket webSocket = newWebSocket();

clientListener.assertFailure(101, null, ProtocolException.class,
"Expected 'Upgrade' header value 'websocket' but was 'null'");

webSocket.cancel();
}

@Test public void wrongUpgradeHeader() throws IOException {
Expand All @@ -389,10 +410,16 @@ private OkHttpClientTestRule configureClientTestRule() {
.setHeader("Upgrade", "Pepsi")
.setHeader("Sec-WebSocket-Accept", "ujmZX4KXZqjwy6vi1aQFH5p4Ygk=")
.build());
newWebSocket();
webServer.enqueue(new MockResponse.Builder()
.socketPolicy(SocketPolicy.DisconnectAtStart.INSTANCE)
.build());

RealWebSocket webSocket = newWebSocket();

clientListener.assertFailure(101, null, ProtocolException.class,
"Expected 'Upgrade' header value 'websocket' but was 'Pepsi'");

webSocket.cancel();
}

@Test public void missingMagicHeader() throws IOException {
Expand All @@ -401,10 +428,16 @@ private OkHttpClientTestRule configureClientTestRule() {
.setHeader("Connection", "Upgrade")
.setHeader("Upgrade", "websocket")
.build());
newWebSocket();
webServer.enqueue(new MockResponse.Builder()
.socketPolicy(SocketPolicy.DisconnectAtStart.INSTANCE)
.build());

RealWebSocket webSocket = newWebSocket();

clientListener.assertFailure(101, null, ProtocolException.class,
"Expected 'Sec-WebSocket-Accept' header value 'ujmZX4KXZqjwy6vi1aQFH5p4Ygk=' but was 'null'");

webSocket.cancel();
}

@Test public void wrongMagicHeader() throws IOException {
Expand All @@ -414,10 +447,16 @@ private OkHttpClientTestRule configureClientTestRule() {
.setHeader("Upgrade", "websocket")
.setHeader("Sec-WebSocket-Accept", "magic")
.build());
newWebSocket();
webServer.enqueue(new MockResponse.Builder()
.socketPolicy(SocketPolicy.DisconnectAtStart.INSTANCE)
.build());

RealWebSocket webSocket = newWebSocket();

clientListener.assertFailure(101, null, ProtocolException.class,
"Expected 'Sec-WebSocket-Accept' header value 'ujmZX4KXZqjwy6vi1aQFH5p4Ygk=' but was 'magic'");

webSocket.cancel();
}

@Test public void clientIncludesForbiddenHeader() throws IOException {
Expand Down Expand Up @@ -868,6 +907,53 @@ private OkHttpClientTestRule configureClientTestRule() {
webSocket.close(1000, null);
}

/** https://github.com/square/okhttp/issues/7768 */
@Test public void reconnectingToNonWebSocket() throws InterruptedException {
// Async test is problematic
client = this.client.newBuilder()
.connectionPool(new ConnectionPool())
.build();

for (int i = 0; i < 30; i++) {
webServer.enqueue(new MockResponse.Builder()
.bodyDelay(100, TimeUnit.MILLISECONDS)
.body("Wrong endpoint")
.code(401)
.build());
}

Request request = new Request.Builder()
.url(webServer.url("/"))
.build();

CountDownLatch attempts = new CountDownLatch(20);

List<WebSocket> webSockets = new ArrayList<>();

WebSocketListener reconnectOnFailure = new WebSocketListener() {
@Override
public void onFailure(WebSocket webSocket, Throwable t, Response response) {
if (attempts.getCount() > 0) {
clientListener.setNextEventDelegate(this);
webSockets.add(client.newWebSocket(request, clientListener));
attempts.countDown();
}
}
};

clientListener.setNextEventDelegate(reconnectOnFailure);

webSockets.add(client.newWebSocket(request, clientListener));

attempts.await();

for (WebSocket webSocket: webSockets) {
webSocket.cancel();
}
client.dispatcher().cancelAll();
client.connectionPool().evictAll();
}

@Test public void compressedMessages() throws Exception {
successfulExtensions("permessage-deflate");
}
Expand Down