Skip to content

Commit

Permalink
Fix websocket reconnect race condition (#7815)
Browse files Browse the repository at this point in the history
  • Loading branch information
yschimke committed May 13, 2023
1 parent 4fc5185 commit f62fd47
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,9 @@ class OkHttpClientTestRule : BeforeEachCallback, AfterEachCallback {
println("After delay: " + connectionPool.connectionCount())
}

assertEquals(0, connectionPool.connectionCount())
assertEquals(0, connectionPool.connectionCount()) {
"Still ${connectionPool.connectionCount()} connections open"
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,9 +170,9 @@ class RealWebSocket(
checkUpgradeSuccess(response, exchange)
streams = exchange!!.newWebSocketStreams()
} catch (e: IOException) {
exchange?.webSocketUpgradeFailed()
failWebSocket(e, response)
response.closeQuietly()
exchange?.webSocketUpgradeFailed()
return
}

Expand Down
98 changes: 92 additions & 6 deletions okhttp/src/jvmTest/java/okhttp3/internal/ws/WebSocketHttpTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import java.net.ProtocolException;
import java.net.SocketTimeoutException;
import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
Expand All @@ -33,6 +35,7 @@
import mockwebserver3.SocketPolicy;
import mockwebserver3.SocketPolicy.KeepOpen;
import mockwebserver3.SocketPolicy.NoResponse;
import okhttp3.ConnectionPool;
import okhttp3.OkHttpClient;
import okhttp3.OkHttpClientTestRule;
import okhttp3.Protocol;
Expand Down Expand Up @@ -351,10 +354,16 @@ private OkHttpClientTestRule configureClientTestRule() {
.setHeader("Upgrade", "websocket")
.setHeader("Sec-WebSocket-Accept", "ujmZX4KXZqjwy6vi1aQFH5p4Ygk=")
.build());
newWebSocket();
webServer.enqueue(new MockResponse.Builder()
.socketPolicy(SocketPolicy.DisconnectAtStart.INSTANCE)
.build());

RealWebSocket webSocket = newWebSocket();

clientListener.assertFailure(101, null, ProtocolException.class,
"Expected 'Connection' header value 'Upgrade' but was 'null'");

webSocket.cancel();
}

@Test public void wrongConnectionHeader() throws IOException {
Expand All @@ -364,10 +373,16 @@ private OkHttpClientTestRule configureClientTestRule() {
.setHeader("Connection", "Downgrade")
.setHeader("Sec-WebSocket-Accept", "ujmZX4KXZqjwy6vi1aQFH5p4Ygk=")
.build());
newWebSocket();
webServer.enqueue(new MockResponse.Builder()
.socketPolicy(SocketPolicy.DisconnectAtStart.INSTANCE)
.build());

RealWebSocket webSocket = newWebSocket();

clientListener.assertFailure(101, null, ProtocolException.class,
"Expected 'Connection' header value 'Upgrade' but was 'Downgrade'");

webSocket.cancel();
}

@Test public void missingUpgradeHeader() throws IOException {
Expand All @@ -376,10 +391,16 @@ private OkHttpClientTestRule configureClientTestRule() {
.setHeader("Connection", "Upgrade")
.setHeader("Sec-WebSocket-Accept", "ujmZX4KXZqjwy6vi1aQFH5p4Ygk=")
.build());
newWebSocket();
webServer.enqueue(new MockResponse.Builder()
.socketPolicy(SocketPolicy.DisconnectAtStart.INSTANCE)
.build());

RealWebSocket webSocket = newWebSocket();

clientListener.assertFailure(101, null, ProtocolException.class,
"Expected 'Upgrade' header value 'websocket' but was 'null'");

webSocket.cancel();
}

@Test public void wrongUpgradeHeader() throws IOException {
Expand All @@ -389,10 +410,16 @@ private OkHttpClientTestRule configureClientTestRule() {
.setHeader("Upgrade", "Pepsi")
.setHeader("Sec-WebSocket-Accept", "ujmZX4KXZqjwy6vi1aQFH5p4Ygk=")
.build());
newWebSocket();
webServer.enqueue(new MockResponse.Builder()
.socketPolicy(SocketPolicy.DisconnectAtStart.INSTANCE)
.build());

RealWebSocket webSocket = newWebSocket();

clientListener.assertFailure(101, null, ProtocolException.class,
"Expected 'Upgrade' header value 'websocket' but was 'Pepsi'");

webSocket.cancel();
}

@Test public void missingMagicHeader() throws IOException {
Expand All @@ -401,10 +428,16 @@ private OkHttpClientTestRule configureClientTestRule() {
.setHeader("Connection", "Upgrade")
.setHeader("Upgrade", "websocket")
.build());
newWebSocket();
webServer.enqueue(new MockResponse.Builder()
.socketPolicy(SocketPolicy.DisconnectAtStart.INSTANCE)
.build());

RealWebSocket webSocket = newWebSocket();

clientListener.assertFailure(101, null, ProtocolException.class,
"Expected 'Sec-WebSocket-Accept' header value 'ujmZX4KXZqjwy6vi1aQFH5p4Ygk=' but was 'null'");

webSocket.cancel();
}

@Test public void wrongMagicHeader() throws IOException {
Expand All @@ -414,10 +447,16 @@ private OkHttpClientTestRule configureClientTestRule() {
.setHeader("Upgrade", "websocket")
.setHeader("Sec-WebSocket-Accept", "magic")
.build());
newWebSocket();
webServer.enqueue(new MockResponse.Builder()
.socketPolicy(SocketPolicy.DisconnectAtStart.INSTANCE)
.build());

RealWebSocket webSocket = newWebSocket();

clientListener.assertFailure(101, null, ProtocolException.class,
"Expected 'Sec-WebSocket-Accept' header value 'ujmZX4KXZqjwy6vi1aQFH5p4Ygk=' but was 'magic'");

webSocket.cancel();
}

@Test public void clientIncludesForbiddenHeader() throws IOException {
Expand Down Expand Up @@ -868,6 +907,53 @@ private OkHttpClientTestRule configureClientTestRule() {
webSocket.close(1000, null);
}

/** https://github.com/square/okhttp/issues/7768 */
@Test public void reconnectingToNonWebSocket() throws InterruptedException {
// Async test is problematic
client = this.client.newBuilder()
.connectionPool(new ConnectionPool())
.build();

for (int i = 0; i < 30; i++) {
webServer.enqueue(new MockResponse.Builder()
.bodyDelay(100, TimeUnit.MILLISECONDS)
.body("Wrong endpoint")
.code(401)
.build());
}

Request request = new Request.Builder()
.url(webServer.url("/"))
.build();

CountDownLatch attempts = new CountDownLatch(20);

List<WebSocket> webSockets = new ArrayList<>();

WebSocketListener reconnectOnFailure = new WebSocketListener() {
@Override
public void onFailure(WebSocket webSocket, Throwable t, Response response) {
if (attempts.getCount() > 0) {
clientListener.setNextEventDelegate(this);
webSockets.add(client.newWebSocket(request, clientListener));
attempts.countDown();
}
}
};

clientListener.setNextEventDelegate(reconnectOnFailure);

webSockets.add(client.newWebSocket(request, clientListener));

attempts.await();

for (WebSocket webSocket: webSockets) {
webSocket.cancel();
}
client.dispatcher().cancelAll();
client.connectionPool().evictAll();
}

@Test public void compressedMessages() throws Exception {
successfulExtensions("permessage-deflate");
}
Expand Down

0 comments on commit f62fd47

Please sign in to comment.