Considering making the RestTemplate configurable in ClientRegistrations #14777

Closed
limo520 opened this issue Mar 19, 2024 · 2 comments
Labels: status: declined (A suggestion or change that we don't feel we should currently apply), type: enhancement (A general enhancement)

limo520 commented Mar 19, 2024

For example, to configure an SSLContext when the authorization server is served over SSL/TLS. Currently there is no way to do that.
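For context, a custom `SSLContext` for an authorization server with a privately issued certificate is usually built from a truststore roughly as below. This is a minimal JDK-only sketch, assuming a file-based truststore (for example PKCS12); `TrustStoreSslContexts` and its methods are hypothetical names. The issue is that `ClientRegistrations` offers no hook to hand such a context to its internal HTTP client.

```java
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.security.KeyStore;
import javax.net.ssl.SSLContext;
import javax.net.ssl.TrustManagerFactory;

// Hypothetical helper: builds an SSLContext that trusts the certificates in a truststore.
final class TrustStoreSslContexts {
    private TrustStoreSslContexts() {}

    /** Loads a truststore file of the given type (e.g. "PKCS12") and wraps it in an SSLContext. */
    static SSLContext fromTrustStore(Path path, char[] password, String type) throws Exception {
        KeyStore trustStore = KeyStore.getInstance(type);
        try (InputStream in = Files.newInputStream(path)) {
            trustStore.load(in, password);
        }
        return fromTrustStore(trustStore);
    }

    /** Passing null makes the TrustManagerFactory fall back to the JDK default truststore. */
    static SSLContext fromTrustStore(KeyStore trustStore) throws Exception {
        TrustManagerFactory tmf =
                TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
        tmf.init(trustStore);
        SSLContext sslContext = SSLContext.getInstance("TLS");
        // No key managers: the context verifies the server but presents no client certificate.
        sslContext.init(null, tmf.getTrustManagers(), null);
        return sslContext;
    }

    public static void main(String[] args) throws Exception {
        SSLContext ctx = fromTrustStore((KeyStore) null);
        System.out.println(ctx.getProtocol()); // TLS
    }
}
```

In a Spring Boot application, a comparable context can be obtained from a `spring.ssl.bundle.*` entry via `SslBundle.createSslContext()`, which is the route taken by the code shared later in this thread.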

limo520 added the status: waiting-for-triage and type: enhancement labels Mar 19, 2024
sjohnr (Member) commented Mar 21, 2024

In 6.2 (released) and 6.3 (not yet released; see the latest milestone), there is a new feature (see gh-11783 and gh-13763) that makes it easier to configure a RestTemplate for OAuth2 Client components. See the docs.

However, regarding ClientRegistrations, please see this comment:

ClientRegistrations is intended to be used as a utility/convenience class. It was designed to fulfill most use cases; however, it may not be suitable for certain ones. For example, if internal network traffic must be routed through a proxy, you can bypass discovery by configuring the authorization-uri and token-uri properties instead of the issuer-uri property.

NOTE: The underlying HTTP Client used in ClientRegistrations was purposely encapsulated and there is no plan to expose it.

I'm going to close this issue with the above explanation.
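If discovery is still wanted but must go through a customized TLS or proxy setup, one workaround in the spirit of the explanation above is to fetch the provider metadata with an HTTP client you control and copy the endpoints into explicit properties. This is a minimal JDK-only sketch, assuming an OpenID Connect provider that serves its metadata at `{issuer}/.well-known/openid-configuration`; `DiscoveryFetcher` is a hypothetical name, only the OIDC location is covered, and JSON parsing is left to whatever library you already use.

```java
import java.io.IOException;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.time.Duration;

// Hypothetical helper: performs the OIDC discovery request with a caller-supplied HttpClient,
// so TLS, proxy and timeouts stay fully under the caller's control.
final class DiscoveryFetcher {
    private DiscoveryFetcher() {}

    /** Builds {issuer}/.well-known/openid-configuration, tolerating one trailing slash. */
    static URI openIdConfigurationUri(String issuer) {
        String base = issuer.endsWith("/") ? issuer.substring(0, issuer.length() - 1) : issuer;
        return URI.create(base + "/.well-known/openid-configuration");
    }

    /** Returns the raw JSON metadata (containing e.g. "authorization_endpoint", "token_endpoint"). */
    static String fetchMetadata(HttpClient client, String issuer)
            throws IOException, InterruptedException {
        HttpRequest request = HttpRequest.newBuilder(openIdConfigurationUri(issuer))
                .header("Accept", "application/json")
                .timeout(Duration.ofSeconds(30))
                .GET()
                .build();
        HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
        if (response.statusCode() != 200) {
            throw new IOException("Discovery failed with HTTP " + response.statusCode());
        }
        return response.body();
    }

    public static void main(String[] args) {
        System.out.println(openIdConfigurationUri("https://idp.example.com/realms/demo/"));
        // https://idp.example.com/realms/demo/.well-known/openid-configuration
    }
}
```

The `HttpClient` passed in can be built with `HttpClient.newBuilder().sslContext(...).proxy(...)`. The endpoints read from the metadata would then go into the `authorization-uri`, `token-uri` (and, for OIDC login, `jwk-set-uri` and `user-info-uri`) provider properties instead of `issuer-uri`.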

sjohnr closed this as completed Mar 21, 2024
sjohnr added the status: declined label and removed the status: waiting-for-triage label Mar 21, 2024
sjohnr self-assigned this Mar 21, 2024
krezovic (Contributor) commented:

I'm sharing the code I wrote to ease customization of RestTemplate and the reactive WebClient:

```java
import lombok.Data;

import org.springframework.boot.autoconfigure.ssl.PropertiesSslBundle;
import org.springframework.boot.autoconfigure.ssl.SslProperties;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.boot.context.properties.bind.Binder;
import org.springframework.boot.ssl.SslBundle;
import org.springframework.boot.ssl.SslBundles;
import org.springframework.core.env.Environment;
import org.springframework.lang.NonNull;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.ObjectUtils;

import java.net.InetSocketAddress;
import java.time.Duration;
import java.util.Optional;

/** Common properties for communication with external services using REST clients. */
@Data
@ConfigurationProperties(ExternalRestClientProperties.PREFIX)
public class ExternalRestClientProperties {
    public static final String PREFIX = "gateway.external-rest-client";

    private static final long DEFAULT_TIMEOUT_SECONDS = 30L;

    /** Connection timeout. */
    private Duration connectTimeout = Duration.ofSeconds(DEFAULT_TIMEOUT_SECONDS);

    /** Read timeout. */
    private Duration readTimeout = Duration.ofSeconds(DEFAULT_TIMEOUT_SECONDS);

    /** HTTP Proxy configuration. */
    private ProxyConfig proxy = new ProxyConfig();

    /** SSL/TLS configuration. */
    private SslConfig ssl = new SslConfig();

    /**
     * @return configured or default (30s) value for connect timeout.
     */
    @NonNull
    public Duration getOrDefaultConnectTimeout() {
        return Optional.ofNullable(connectTimeout)
                .orElse(Duration.ofSeconds(DEFAULT_TIMEOUT_SECONDS));
    }

    /**
     * @return configured or default (30s) value for read timeout.
     */
    @NonNull
    public Duration getOrDefaultReadTimeout() {
        return Optional.ofNullable(readTimeout).orElse(Duration.ofSeconds(DEFAULT_TIMEOUT_SECONDS));
    }

    /** HTTP Proxy configuration. */
    @Data
    public static class ProxyConfig {
        /** HTTP Proxy Host. */
        private String host;

        /** HTTP Proxy Port. */
        private Integer port;

        /**
         * @return a {@link InetSocketAddress} instance for provided {@link #host} and {@link #port}
         *     or null if either of them is null.
         */
        @Nullable
        public InetSocketAddress toInetSocketAddress() {
            if (!ObjectUtils.isEmpty(host) && port != null) {
                Assert.isTrue(port > 0 && port <= 65535, "Proxy port must be between 1 and 65535");
                return new InetSocketAddress(host, port);
            }

            return null;
        }
    }

    /** SSL/TLS configuration. */
    @Data
    public static class SslConfig {
        /** SSL Bundle Name as configured via spring.ssl configuration properties. */
        private String bundle;

        /** Skip verification of insecure certificates. */
        private boolean insecureSkipVerify = false;

        /**
         * Getter to extract a {@link SslBundle} from provided {@link SslBundles}.
         *
         * @param sslBundles to extract {@link SslBundle} from
         * @return {@link SslBundle} or null if {@link #bundle} is null
         */
        @Nullable
        public SslBundle toSslBundle(@NonNull SslBundles sslBundles) {
            if (insecureSkipVerify || ObjectUtils.isEmpty(bundle)) {
                return null;
            }

            return sslBundles.getBundle(bundle);
        }

        /**
         * Getter to extract a {@link SslBundle} from provided {@link Environment}. It should only
         * be used for early bootstrap phase where no {@link SslBundles} instance has been
         * configured.
         *
         * @param environment to extract {@link SslBundle} from
         * @return {@link SslBundle} or null if {@link #bundle} is null
         */
        @Nullable
        public SslBundle toSslBundle(@NonNull Environment environment) {
            if (insecureSkipVerify || ObjectUtils.isEmpty(bundle)) {
                return null;
            }

            var sslProperties =
                    Binder.get(environment)
                            .bind("spring.ssl", SslProperties.class)
                            .orElseGet(SslProperties::new);

            var jksBundles = sslProperties.getBundle().getJks();
            var pemBundles = sslProperties.getBundle().getPem();

            if (jksBundles.containsKey(bundle)) {
                return PropertiesSslBundle.get(jksBundles.get(bundle));
            } else if (pemBundles.containsKey(bundle)) {
                return PropertiesSslBundle.get(pemBundles.get(bundle));
            } else {
                throw new IllegalArgumentException("Cannot find SSL bundle with name " + bundle);
            }
        }
    }

    /**
     * Creates a {@link ExternalRestClientProperties} instance using {@link Environment} from the
     * default configuration prefix.
     *
     * @param environment to extract {@link ExternalRestClientProperties} value from
     * @return created {@link ExternalRestClientProperties}
     */
    @NonNull
    public static ExternalRestClientProperties create(@NonNull Environment environment) {
        return create(environment, PREFIX);
    }

    /**
     * Creates a {@link ExternalRestClientProperties} instance using {@link Environment} from
     * provided configuration prefix.
     *
     * @param environment to extract {@link ExternalRestClientProperties} value from
     * @param prefix a prefix in the configuration hierarchy for the properties instance
     * @return created {@link ExternalRestClientProperties}
     */
    @NonNull
    public static ExternalRestClientProperties create(
            @NonNull Environment environment, @NonNull String prefix) {
        return Binder.get(environment)
                .bind(prefix, ExternalRestClientProperties.class)
                .orElseGet(ExternalRestClientProperties::new);
    }
}
```

```java
import lombok.extern.slf4j.Slf4j;

import org.springframework.boot.ssl.SslBundle;
import org.springframework.boot.ssl.SslBundles;
import org.springframework.boot.web.client.RestTemplateBuilder;
import org.springframework.http.client.ClientHttpRequestFactory;
import org.springframework.http.client.JdkClientHttpRequestFactory;
import org.springframework.lang.NonNull;
import org.springframework.lang.Nullable;
import org.springframework.web.client.RestTemplate;

import java.net.ProxySelector;
import java.net.Socket;
import java.net.http.HttpClient;
import java.security.SecureRandom;
import java.security.cert.X509Certificate;

import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.TrustManager;
import javax.net.ssl.X509ExtendedTrustManager;

/** Utility for creating RestTemplate for communication with external services. */
@Slf4j
public final class ExternalRestTemplateUtils {
    private ExternalRestTemplateUtils() {}

    @SuppressWarnings("java:S4830")
    private static final TrustManager INSECURE_TRUST_MANAGER =
            new X509ExtendedTrustManager() {
                @Override
                public X509Certificate[] getAcceptedIssuers() {
                    return new X509Certificate[] {};
                }

                @Override
                public void checkClientTrusted(X509Certificate[] chain, String authType) {
                    // no-op
                }

                @Override
                public void checkServerTrusted(X509Certificate[] chain, String authType) {
                    // no-op
                }

                @Override
                public void checkClientTrusted(
                        X509Certificate[] chain, String authType, Socket socket) {
                    // no-op
                }

                @Override
                public void checkServerTrusted(
                        X509Certificate[] chain, String authType, Socket socket) {
                    // no-op
                }

                @Override
                public void checkClientTrusted(
                        X509Certificate[] chain, String authType, SSLEngine engine) {
                    // no-op
                }

                @Override
                public void checkServerTrusted(
                        X509Certificate[] chain, String authType, SSLEngine engine) {
                    // no-op
                }
            };

    /**
     * Creates a {@link RestTemplate} using {@link ExternalRestClientProperties}.
     *
     * @param properties used for configuration
     * @param sslBundles used for trust/key store configuration
     * @param builder optional preconfigured {@link RestTemplateBuilder}
     * @return created {@link RestTemplate}
     */
    public static RestTemplate restTemplate(
            @NonNull ExternalRestClientProperties properties,
            @NonNull SslBundles sslBundles,
            @Nullable RestTemplateBuilder builder) {
        var sslConfig = properties.getSsl();

        if (builder == null) {
            builder = new RestTemplateBuilder();
        }

        var sslBundle = sslConfig.toSslBundle(sslBundles);
        var requestFactory = requestFactory(properties, sslBundle);

        return builder.requestFactory(s -> requestFactory).build();
    }

    /**
     * Creates a {@link ClientHttpRequestFactory} using {@link ExternalRestClientProperties} and
     * provided {@link SslBundle}.
     *
     * @param properties used for configuration
     * @param sslBundle used for trust/key store configuration
     * @return created {@link ClientHttpRequestFactory}
     */
    public static ClientHttpRequestFactory requestFactory(
            @NonNull ExternalRestClientProperties properties, @Nullable SslBundle sslBundle) {

        var sslConfig = properties.getSsl();
        var proxyConfig = properties.getProxy();
        var proxyAddress = proxyConfig.toInetSocketAddress();
        var proxy = proxyAddress != null ? ProxySelector.of(proxyAddress) : null;

        var builder = HttpClient.newBuilder();

        if (properties.getConnectTimeout() != null) {
            builder.connectTimeout(properties.getConnectTimeout());
        }

        if (sslBundle != null) {
            builder.sslContext(sslBundle.createSslContext());
        } else if (sslConfig.isInsecureSkipVerify()) {
            log.warn("Creating RestTemplate instance without TLS certificate verification");
            builder.sslContext(insecureSslContext());
        }

        if (proxy != null) {
            builder.proxy(proxy);
        }

        var factory = new JdkClientHttpRequestFactory(builder.build());

        if (properties.getReadTimeout() != null) {
            factory.setReadTimeout(properties.getReadTimeout());
        }

        return factory;
    }

    private static SSLContext insecureSslContext() {
        try {
            var sslContext = SSLContext.getInstance("TLS");
            sslContext.init(null, new TrustManager[] {INSECURE_TRUST_MANAGER}, new SecureRandom());
            return sslContext;
        } catch (Exception e) {
            log.trace("Caught exception", e);
            throw new IllegalStateException(e);
        }
    }
}
```
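Since `JdkClientHttpRequestFactory` wraps a `java.net.http.HttpClient`, the settings in `requestFactory(...)` correspond to plain JDK builder calls. The JDK-only sketch below shows that mapping in isolation; `JdkClientSketch` and the proxy host `proxy.internal:3128` are hypothetical. The JDK client has no client-wide read timeout, so the sketch sets it per request, which is assumed here to be how the factory's read timeout is applied (worth verifying against your Spring Framework version).

```java
import java.net.InetSocketAddress;
import java.net.ProxySelector;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.time.Duration;
import javax.net.ssl.SSLContext;

// Hypothetical JDK-only equivalent of the HttpClient configuration in requestFactory(...).
final class JdkClientSketch {
    private JdkClientSketch() {}

    /** proxy and sslContext may be null, in which case the JDK defaults are kept. */
    static HttpClient client(Duration connectTimeout, InetSocketAddress proxy, SSLContext sslContext) {
        HttpClient.Builder builder = HttpClient.newBuilder().connectTimeout(connectTimeout);
        if (proxy != null) {
            builder.proxy(ProxySelector.of(proxy)); // route all requests through one HTTP proxy
        }
        if (sslContext != null) {
            builder.sslContext(sslContext); // custom trust/key material
        }
        return builder.build();
    }

    /** The response ("read") timeout is a property of each HttpRequest, not of the client. */
    static HttpRequest get(URI uri, Duration readTimeout) {
        return HttpRequest.newBuilder(uri).timeout(readTimeout).GET().build();
    }

    public static void main(String[] args) throws Exception {
        HttpClient client = client(Duration.ofSeconds(30),
                InetSocketAddress.createUnresolved("proxy.internal", 3128),
                SSLContext.getDefault());
        HttpRequest request = get(URI.create("https://idp.example.com/"), Duration.ofSeconds(30));
        System.out.println(client.proxy().isPresent() + " " + request.timeout().orElseThrow());
    }
}
```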

```java
import io.netty.channel.ChannelOption;
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslContextBuilder;
import io.netty.handler.ssl.util.InsecureTrustManagerFactory;

import lombok.extern.slf4j.Slf4j;

import org.springframework.boot.ssl.SslBundle;
import org.springframework.boot.ssl.SslBundles;
import org.springframework.boot.ssl.SslManagerBundle;
import org.springframework.boot.ssl.SslOptions;
import org.springframework.http.client.reactive.ReactorClientHttpConnector;
import org.springframework.lang.NonNull;
import org.springframework.lang.Nullable;
import org.springframework.util.function.ThrowingConsumer;
import org.springframework.web.reactive.function.client.WebClient;

import reactor.netty.http.client.HttpClient;
import reactor.netty.tcp.SslProvider;
import reactor.netty.transport.ProxyProvider;

import java.net.InetSocketAddress;
import java.util.function.Consumer;

/** Utility for creating WebClient for communication with external services. */
@Slf4j
public final class ExternalWebClientUtils {
    private ExternalWebClientUtils() {}

    /**
     * Creates a {@link WebClient} using {@link ExternalRestClientProperties}.
     *
     * @param properties used for configuration
     * @param sslBundles used for trust/key store configuration
     * @param builder optional preconfigured {@link WebClient.Builder}
     * @return created {@link WebClient}
     */
    public static WebClient webClient(
            @NonNull ExternalRestClientProperties properties,
            @NonNull SslBundles sslBundles,
            @Nullable WebClient.Builder builder) {
        var sslConfig = properties.getSsl();

        if (builder == null) {
            builder = WebClient.builder();
        }

        var sslBundle = sslConfig.toSslBundle(sslBundles);
        var connector = createConnector(properties, sslBundle);

        return builder.clientConnector(connector).build();
    }

    private static ReactorClientHttpConnector createConnector(
            @NonNull ExternalRestClientProperties properties, @Nullable SslBundle sslBundle) {

        var sslConfig = properties.getSsl();
        var proxyConfig = properties.getProxy();
        var proxyAddress = proxyConfig.toInetSocketAddress();

        HttpClient client =
                HttpClient.create()
                        .compress(true)
                        .responseTimeout(properties.getOrDefaultReadTimeout())
                        .option(
                                ChannelOption.CONNECT_TIMEOUT_MILLIS,
                                (int) properties.getOrDefaultConnectTimeout().toMillis());

        if (sslConfig.isInsecureSkipVerify() || sslBundle != null) {
            client = client.secure(ssl(sslConfig, sslBundle));
        }

        if (proxyAddress != null) {
            client = client.proxy(proxy(proxyAddress));
        }

        return new ReactorClientHttpConnector(client);
    }

    private static Consumer<ProxyProvider.TypeSpec> proxy(InetSocketAddress proxyAddress) {
        return spec -> spec.type(ProxyProvider.Proxy.HTTP).address(proxyAddress);
    }

    private static Consumer<SslProvider.SslContextSpec> ssl(
            ExternalRestClientProperties.SslConfig sslConfig, SslBundle sslBundle) {
        if (sslConfig.isInsecureSkipVerify()) {
            log.warn("Creating WebClient instance without TLS certificate verification");
            return spec -> spec.sslContext(insecureSslContext());
        }

        return ThrowingConsumer.of(
                spec -> {
                    SslOptions options = sslBundle.getOptions();
                    SslManagerBundle managers = sslBundle.getManagers();
                    SslContextBuilder builder =
                            SslContextBuilder.forClient()
                                    .keyManager(managers.getKeyManagerFactory())
                                    .trustManager(managers.getTrustManagerFactory())
                                    .ciphers(SslOptions.asSet(options.getCiphers()))
                                    .protocols(options.getEnabledProtocols());

                    spec.sslContext(builder.build());
                });
    }

    private static SslContext insecureSslContext() {
        try {
            return SslContextBuilder.forClient()
                    .trustManager(InsecureTrustManagerFactory.INSTANCE)
                    .build();
        } catch (Exception e) {
            log.trace("Caught exception", e);
            throw new IllegalStateException(e);
        }
    }
}
```

```java
import org.springframework.core.env.Environment;
import org.springframework.lang.NonNull;
import org.springframework.security.oauth2.client.registration.ClientRegistrations;
import org.springframework.util.Assert;
import org.springframework.util.ReflectionUtils;
import org.springframework.web.client.RestTemplate;

/** Utilities for configuring Spring Security to use External REST client(s). */
public class ExternalSpringSecurityClientUtils {
    private ExternalSpringSecurityClientUtils() {}

    /**
     * Performs a "monkey patching" on {@link ClientRegistrations} internal {@link RestTemplate}
     * instance.
     *
     * <p>NOTE: this relies on the private field {@code rest}, an internal detail of
     * {@link ClientRegistrations} that may change between Spring Security versions.
     *
     * @param environment the environment that will be used to construct {@link
     *     ExternalRestClientProperties}
     */
    public static void patchClientRegistrationsRestTemplate(@NonNull Environment environment) {
        // Locate the private static RestTemplate and read it (null receiver for a static field).
        var restField = ReflectionUtils.findField(ClientRegistrations.class, "rest");
        Assert.notNull(restField, "ClientRegistrations class should have a field 'rest'");
        ReflectionUtils.makeAccessible(restField);
        var rest = (RestTemplate) ReflectionUtils.getField(restField, null);
        Assert.notNull(rest, "ClientRegistrations class 'rest' field should be a static instance");

        // Early bootstrap: no SslBundles bean exists yet, so resolve the bundle from the Environment.
        var properties = ExternalRestClientProperties.create(environment);
        var sslBundle = properties.getSsl().toSslBundle(environment);
        var requestFactory = ExternalRestTemplateUtils.requestFactory(properties, sslBundle);

        // Reconfigure the existing instance in place instead of replacing the field.
        rest.setRequestFactory(requestFactory);
    }
}
```
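The patch reconfigures the `RestTemplate` the field already points to rather than swapping in a new one. Besides keeping any other settings on that instance, this also avoids a JDK restriction: `Field.set` refuses to write a `static final` field even after `setAccessible(true)`, while the referenced object stays mutable. The JDK-only sketch below reproduces that distinction with hypothetical stand-in classes (`Registrations` for `ClientRegistrations`, `Client` for `RestTemplate`, and a string in place of a request factory); it is an illustration, not Spring Security code.

```java
import java.lang.reflect.Field;

// Hypothetical stand-ins: Registrations ~ ClientRegistrations, Client ~ RestTemplate.
final class ReflectionPatchSketch {
    static final class Client {
        private String requestFactory = "default";
        void setRequestFactory(String requestFactory) { this.requestFactory = requestFactory; }
        String getRequestFactory() { return requestFactory; }
    }

    static final class Registrations {
        private static final Client rest = new Client();
        static String currentFactory() { return rest.getRequestFactory(); }
    }

    /** Reads the private static field and reconfigures the referenced instance in place. */
    static void patch(String newFactory) throws ReflectiveOperationException {
        Field field = Registrations.class.getDeclaredField("rest");
        field.setAccessible(true);
        Client client = (Client) field.get(null); // null receiver: static field
        client.setRequestFactory(newFactory);
    }

    /** Tries to replace the instance; static final fields reject Field.set. */
    static boolean canReplace() throws NoSuchFieldException {
        Field field = Registrations.class.getDeclaredField("rest");
        field.setAccessible(true);
        try {
            field.set(null, new Client());
            return true;
        } catch (IllegalAccessException e) {
            return false;
        }
    }

    public static void main(String[] args) throws Exception {
        patch("custom");
        System.out.println(Registrations.currentFactory()); // custom
        System.out.println(canReplace()); // false
    }
}
```

Because the mutated object is shared JVM-wide, the patch must run before the first discovery call, which is why the author triggers it from an auto-configuration ordered before Spring Security's.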

As a last step, I call the patch method (`patchClientRegistrationsRestTemplate`) from an auto-configuration class that runs before Spring Security's OAuth2 auto-configuration (I use the reactive stack, so adjust accordingly):

```java
import org.springframework.boot.autoconfigure.AutoConfiguration;
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.autoconfigure.security.oauth2.client.reactive.ReactiveOAuth2ClientAutoConfiguration;
import org.springframework.boot.autoconfigure.security.oauth2.resource.reactive.ReactiveOAuth2ResourceServerAutoConfiguration;
import org.springframework.core.env.Environment;

// Register this class in
// META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports
// so that Spring Boot loads it as an auto-configuration and honours the "before" ordering.
@AutoConfiguration(
        before = {
            ReactiveOAuth2ClientAutoConfiguration.class,
            ReactiveOAuth2ResourceServerAutoConfiguration.class
        })
public class ClientRegistrationsRestClientAutoConfiguration {
    ClientRegistrationsRestClientAutoConfiguration(Environment environment) {
        // Patch the shared RestTemplate before any ClientRegistrations discovery call is made.
        ExternalSpringSecurityClientUtils.patchClientRegistrationsRestTemplate(environment);
    }
}
```
