Skip to content

Commit

Permalink
Support for SSLContext configuration on StandardWebSocketClient
Browse files Browse the repository at this point in the history
Closes gh-30680
  • Loading branch information
jhoeller committed Dec 28, 2023
1 parent 989625d commit f5b4f7d
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 17 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2022 the original author or authors.
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -114,7 +114,7 @@ private Mono<Void> executeInternal(URI url, HttpHeaders requestHeaders, WebSocke
return Mono.error(ex);
}
})
.subscribeOn(Schedulers.boundedElastic()); // connectToServer is blocking
.subscribeOn(Schedulers.boundedElastic()); // connectToServer is blocking
}

private StandardWebSocketHandlerAdapter createEndpoint(URI url, WebSocketHandler handler,
Expand All @@ -130,24 +130,38 @@ private HandshakeInfo createHandshakeInfo(URI url, DefaultConfigurator configura
return new HandshakeInfo(url, responseHeaders, Mono.empty(), protocol);
}

/**
* Create the {@link StandardWebSocketSession} for the given Jakarta WebSocket Session.
* @see #bufferFactory()
*/
protected StandardWebSocketSession createWebSocketSession(
Session session, HandshakeInfo info, Sinks.Empty<Void> completionSink) {

return new StandardWebSocketSession(
session, info, DefaultDataBufferFactory.sharedInstance, completionSink);
return new StandardWebSocketSession(session, info, bufferFactory(), completionSink);
}

/**
* Return the {@link DataBufferFactory} to use.
* @see #createWebSocketSession
*/
protected DataBufferFactory bufferFactory() {
return DefaultDataBufferFactory.sharedInstance;
}

private ClientEndpointConfig createEndpointConfig(Configurator configurator, List<String> subProtocols) {
/**
* Create the {@link ClientEndpointConfig} for the given configurator.
* Can be overridden to add extensions or an SSL context.
* @param configurator the configurator to apply
* @param subProtocols the preferred sub-protocols
* @since 6.1.3
*/
protected ClientEndpointConfig createEndpointConfig(Configurator configurator, List<String> subProtocols) {
return ClientEndpointConfig.Builder.create()
.configurator(configurator)
.preferredSubprotocols(subProtocols)
.build();
}

protected DataBufferFactory bufferFactory() {
return DefaultDataBufferFactory.sharedInstance;
}


private static final class DefaultConfigurator extends Configurator {

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2022 the original author or authors.
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -28,6 +28,8 @@
import java.util.concurrent.Callable;
import java.util.concurrent.CompletableFuture;

import javax.net.ssl.SSLContext;

import jakarta.websocket.ClientEndpointConfig;
import jakarta.websocket.ClientEndpointConfig.Configurator;
import jakarta.websocket.ContainerProvider;
Expand All @@ -54,6 +56,7 @@
* A WebSocketClient based on the standard Jakarta WebSocket API.
*
* @author Rossen Stoyanchev
* @author Juergen Hoeller
* @since 4.0
*/
public class StandardWebSocketClient extends AbstractWebSocketClient {
Expand All @@ -62,6 +65,9 @@ public class StandardWebSocketClient extends AbstractWebSocketClient {

private final Map<String,Object> userProperties = new HashMap<>();

@Nullable
private SSLContext sslContext;

@Nullable
private AsyncTaskExecutor taskExecutor = new SimpleAsyncTaskExecutor();

Expand Down Expand Up @@ -100,12 +106,29 @@ public void setUserProperties(@Nullable Map<String, Object> userProperties) {
}

/**
* The configured user properties.
* Return the configured user properties.
*/
public Map<String, Object> getUserProperties() {
return this.userProperties;
}

/**
* Set the {@link SSLContext} to use for {@link ClientEndpointConfig#getSSLContext()}.
* @since 6.1.3
*/
public void setSslContext(@Nullable SSLContext sslContext) {
this.sslContext = sslContext;
}

/**
* Return the {@link SSLContext} to use.
* @since 6.1.3
*/
@Nullable
public SSLContext getSslContext() {
return this.sslContext;
}

/**
* Set an {@link AsyncTaskExecutor} to use when opening connections.
* <p>If this property is set to {@code null}, calls to any of the
Expand Down Expand Up @@ -134,17 +157,19 @@ protected CompletableFuture<WebSocketSession> executeInternal(WebSocketHandler w
InetSocketAddress localAddress = new InetSocketAddress(getLocalHost(), port);
InetSocketAddress remoteAddress = new InetSocketAddress(uri.getHost(), port);

final StandardWebSocketSession session = new StandardWebSocketSession(headers,
StandardWebSocketSession session = new StandardWebSocketSession(headers,
attributes, localAddress, remoteAddress);

final ClientEndpointConfig endpointConfig = ClientEndpointConfig.Builder.create()
ClientEndpointConfig endpointConfig = ClientEndpointConfig.Builder.create()
.configurator(new StandardWebSocketClientConfigurator(headers))
.preferredSubprotocols(protocols)
.extensions(adaptExtensions(extensions)).build();
.extensions(adaptExtensions(extensions))
.sslContext(getSslContext())
.build();

endpointConfig.getUserProperties().putAll(getUserProperties());

final Endpoint endpoint = new StandardWebSocketHandlerAdapter(webSocketHandler, session);
Endpoint endpoint = new StandardWebSocketHandlerAdapter(webSocketHandler, session);

Callable<WebSocketSession> connectTask = () -> {
this.webSocketContainer.connectToServer(endpoint, endpointConfig, uri);
Expand All @@ -167,7 +192,7 @@ private static List<Extension> adaptExtensions(List<WebSocketExtension> extensio
return result;
}

private InetAddress getLocalHost() {
private static InetAddress getLocalHost() {
try {
return InetAddress.getLocalHost();
}
Expand All @@ -176,7 +201,7 @@ private InetAddress getLocalHost() {
}
}

private int getPort(URI uri) {
private static int getPort(URI uri) {
if (uri.getPort() == -1) {
String scheme = uri.getScheme().toLowerCase(Locale.ENGLISH);
return ("wss".equals(scheme) ? 443 : 80);
Expand Down

0 comments on commit f5b4f7d

Please sign in to comment.