Skip to content

Commit

Permalink
Configure token-exchange via a bean
Browse files Browse the repository at this point in the history
Issue gh-5199
Issue gh-11783
Closes gh-14701
  • Loading branch information
sjohnr committed Mar 7, 2024
1 parent 85c3d0a commit d6382b8
Show file tree
Hide file tree
Showing 6 changed files with 208 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,13 @@
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProvider;
import org.springframework.security.oauth2.client.PasswordOAuth2AuthorizedClientProvider;
import org.springframework.security.oauth2.client.RefreshTokenOAuth2AuthorizedClientProvider;
import org.springframework.security.oauth2.client.TokenExchangeOAuth2AuthorizedClientProvider;
import org.springframework.security.oauth2.client.endpoint.JwtBearerGrantRequest;
import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient;
import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest;
import org.springframework.security.oauth2.client.endpoint.OAuth2PasswordGrantRequest;
import org.springframework.security.oauth2.client.endpoint.OAuth2RefreshTokenGrantRequest;
import org.springframework.security.oauth2.client.endpoint.TokenExchangeGrantRequest;
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
import org.springframework.security.oauth2.client.web.DefaultOAuth2AuthorizedClientManager;
import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
Expand Down Expand Up @@ -183,7 +185,8 @@ static final class OAuth2AuthorizedClientManagerRegistrar
RefreshTokenOAuth2AuthorizedClientProvider.class,
ClientCredentialsOAuth2AuthorizedClientProvider.class,
PasswordOAuth2AuthorizedClientProvider.class,
JwtBearerOAuth2AuthorizedClientProvider.class
JwtBearerOAuth2AuthorizedClientProvider.class,
TokenExchangeOAuth2AuthorizedClientProvider.class
);
// @formatter:on

Expand Down Expand Up @@ -255,6 +258,12 @@ OAuth2AuthorizedClientManager getAuthorizedClientManager() {
authorizedClientProviders.add(jwtBearerAuthorizedClientProvider);
}

OAuth2AuthorizedClientProvider tokenExchangeAuthorizedClientProvider = getTokenExchangeAuthorizedClientProvider(
authorizedClientProviderBeans);
if (tokenExchangeAuthorizedClientProvider != null) {
authorizedClientProviders.add(tokenExchangeAuthorizedClientProvider);
}

authorizedClientProviders.addAll(getAdditionalAuthorizedClientProviders(authorizedClientProviderBeans));
authorizedClientProvider = new DelegatingOAuth2AuthorizedClientProvider(authorizedClientProviders);
}
Expand Down Expand Up @@ -364,6 +373,25 @@ private OAuth2AuthorizedClientProvider getJwtBearerAuthorizedClientProvider(
return authorizedClientProvider;
}

private OAuth2AuthorizedClientProvider getTokenExchangeAuthorizedClientProvider(
Collection<OAuth2AuthorizedClientProvider> authorizedClientProviders) {
TokenExchangeOAuth2AuthorizedClientProvider authorizedClientProvider = getAuthorizedClientProviderByType(
authorizedClientProviders, TokenExchangeOAuth2AuthorizedClientProvider.class);

OAuth2AccessTokenResponseClient<TokenExchangeGrantRequest> accessTokenResponseClient = getBeanOfType(
ResolvableType.forClassWithGenerics(OAuth2AccessTokenResponseClient.class,
TokenExchangeGrantRequest.class));
if (accessTokenResponseClient != null) {
if (authorizedClientProvider == null) {
authorizedClientProvider = new TokenExchangeOAuth2AuthorizedClientProvider();
}

authorizedClientProvider.setAccessTokenResponseClient(accessTokenResponseClient);
}

return authorizedClientProvider;
}

private List<OAuth2AuthorizedClientProvider> getAdditionalAuthorizedClientProviders(
Collection<OAuth2AuthorizedClientProvider> authorizedClientProviders) {
List<OAuth2AuthorizedClientProvider> additionalAuthorizedClientProviders = new ArrayList<>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,13 @@
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProvider;
import org.springframework.security.oauth2.client.PasswordOAuth2AuthorizedClientProvider;
import org.springframework.security.oauth2.client.RefreshTokenOAuth2AuthorizedClientProvider;
import org.springframework.security.oauth2.client.TokenExchangeOAuth2AuthorizedClientProvider;
import org.springframework.security.oauth2.client.endpoint.JwtBearerGrantRequest;
import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient;
import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest;
import org.springframework.security.oauth2.client.endpoint.OAuth2PasswordGrantRequest;
import org.springframework.security.oauth2.client.endpoint.OAuth2RefreshTokenGrantRequest;
import org.springframework.security.oauth2.client.endpoint.TokenExchangeGrantRequest;
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
import org.springframework.security.oauth2.client.web.DefaultOAuth2AuthorizedClientManager;
import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
Expand Down Expand Up @@ -76,7 +78,8 @@ final class OAuth2AuthorizedClientManagerRegistrar implements BeanDefinitionRegi
RefreshTokenOAuth2AuthorizedClientProvider.class,
ClientCredentialsOAuth2AuthorizedClientProvider.class,
PasswordOAuth2AuthorizedClientProvider.class,
JwtBearerOAuth2AuthorizedClientProvider.class
JwtBearerOAuth2AuthorizedClientProvider.class,
TokenExchangeOAuth2AuthorizedClientProvider.class
);
// @formatter:on

Expand Down Expand Up @@ -137,6 +140,12 @@ private OAuth2AuthorizedClientManager getAuthorizedClientManager() {
authorizedClientProviders.add(jwtBearerAuthorizedClientProvider);
}

OAuth2AuthorizedClientProvider tokenExchangeAuthorizedClientProvider = getTokenExchangeAuthorizedClientProvider(
authorizedClientProviderBeans);
if (tokenExchangeAuthorizedClientProvider != null) {
authorizedClientProviders.add(tokenExchangeAuthorizedClientProvider);
}

authorizedClientProviders.addAll(getAdditionalAuthorizedClientProviders(authorizedClientProviderBeans));
authorizedClientProvider = new DelegatingOAuth2AuthorizedClientProvider(authorizedClientProviders);
}
Expand Down Expand Up @@ -245,6 +254,25 @@ private OAuth2AuthorizedClientProvider getJwtBearerAuthorizedClientProvider(
return authorizedClientProvider;
}

private OAuth2AuthorizedClientProvider getTokenExchangeAuthorizedClientProvider(
Collection<OAuth2AuthorizedClientProvider> authorizedClientProviders) {
TokenExchangeOAuth2AuthorizedClientProvider authorizedClientProvider = getAuthorizedClientProviderByType(
authorizedClientProviders, TokenExchangeOAuth2AuthorizedClientProvider.class);

OAuth2AccessTokenResponseClient<TokenExchangeGrantRequest> accessTokenResponseClient = getBeanOfType(
ResolvableType.forClassWithGenerics(OAuth2AccessTokenResponseClient.class,
TokenExchangeGrantRequest.class));
if (accessTokenResponseClient != null) {
if (authorizedClientProvider == null) {
authorizedClientProvider = new TokenExchangeOAuth2AuthorizedClientProvider();
}

authorizedClientProvider.setAccessTokenResponseClient(accessTokenResponseClient);
}

return authorizedClientProvider;
}

private List<OAuth2AuthorizedClientProvider> getAdditionalAuthorizedClientProviders(
Collection<OAuth2AuthorizedClientProvider> authorizedClientProviders) {
List<OAuth2AuthorizedClientProvider> additionalAuthorizedClientProviders = new ArrayList<>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,15 @@
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientManager;
import org.springframework.security.oauth2.client.PasswordOAuth2AuthorizedClientProvider;
import org.springframework.security.oauth2.client.RefreshTokenOAuth2AuthorizedClientProvider;
import org.springframework.security.oauth2.client.TokenExchangeOAuth2AuthorizedClientProvider;
import org.springframework.security.oauth2.client.endpoint.AbstractOAuth2AuthorizationGrantRequest;
import org.springframework.security.oauth2.client.endpoint.JwtBearerGrantRequest;
import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient;
import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest;
import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest;
import org.springframework.security.oauth2.client.endpoint.OAuth2PasswordGrantRequest;
import org.springframework.security.oauth2.client.endpoint.OAuth2RefreshTokenGrantRequest;
import org.springframework.security.oauth2.client.endpoint.TokenExchangeGrantRequest;
import org.springframework.security.oauth2.client.oidc.userinfo.OidcUserRequest;
import org.springframework.security.oauth2.client.oidc.userinfo.OidcUserService;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
Expand All @@ -70,6 +72,7 @@
import org.springframework.security.oauth2.client.web.DefaultOAuth2AuthorizedClientManager;
import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
import org.springframework.security.oauth2.core.OAuth2Error;
Expand Down Expand Up @@ -327,6 +330,47 @@ private void testJwtBearerGrant() {
assertThat(grantRequest.getJwt().getSubject()).isEqualTo("user");
}

@Test
public void authorizeWhenTokenExchangeAccessTokenResponseClientBeanThenUsed() {
this.spring.register(CustomAccessTokenResponseClientsConfig.class).autowire();
testTokenExchangeGrant();
}

@Test
public void authorizeWhenTokenExchangeAuthorizedClientProviderBeanThenUsed() {
this.spring.register(CustomAuthorizedClientProvidersConfig.class).autowire();
testTokenExchangeGrant();
}

private void testTokenExchangeGrant() {
OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build();
given(MOCK_RESPONSE_CLIENT.getTokenResponse(any(TokenExchangeGrantRequest.class)))
.willReturn(accessTokenResponse);

JwtAuthenticationToken authentication = new JwtAuthenticationToken(getJwt());
ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId("auth0");
// @formatter:off
OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest
.withClientRegistrationId(clientRegistration.getRegistrationId())
.principal(authentication)
.attribute(HttpServletRequest.class.getName(), this.request)
.attribute(HttpServletResponse.class.getName(), this.response)
.build();
// @formatter:on
OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest);
assertThat(authorizedClient).isNotNull();

ArgumentCaptor<TokenExchangeGrantRequest> grantRequestCaptor = ArgumentCaptor
.forClass(TokenExchangeGrantRequest.class);
verify(MOCK_RESPONSE_CLIENT).getTokenResponse(grantRequestCaptor.capture());

TokenExchangeGrantRequest grantRequest = grantRequestCaptor.getValue();
assertThat(grantRequest.getClientRegistration().getRegistrationId())
.isEqualTo(clientRegistration.getRegistrationId());
assertThat(grantRequest.getGrantType()).isEqualTo(AuthorizationGrantType.TOKEN_EXCHANGE);
assertThat(grantRequest.getSubjectToken()).isEqualTo(authentication.getToken());
}

private static OAuth2AccessToken getExpiredAccessToken() {
Instant expiresAt = Instant.now().minusSeconds(60);
Instant issuedAt = expiresAt.minus(Duration.ofDays(1));
Expand Down Expand Up @@ -376,6 +420,11 @@ OAuth2AccessTokenResponseClient<JwtBearerGrantRequest> jwtBearerTokenResponseCli
return new MockJwtBearerClient();
}

@Bean
OAuth2AccessTokenResponseClient<TokenExchangeGrantRequest> tokenExchangeTokenResponseClient() {
return new MockTokenExchangeClient();
}

@Bean
OAuth2UserService<OAuth2UserRequest, OAuth2User> oauth2UserService() {
return mock(DefaultOAuth2UserService.class);
Expand Down Expand Up @@ -425,6 +474,13 @@ JwtBearerOAuth2AuthorizedClientProvider jwtBearerAuthorizedClientProvider() {
return authorizedClientProvider;
}

@Bean
TokenExchangeOAuth2AuthorizedClientProvider tokenExchangeAuthorizedClientProvider() {
TokenExchangeOAuth2AuthorizedClientProvider authorizedClientProvider = new TokenExchangeOAuth2AuthorizedClientProvider();
authorizedClientProvider.setAccessTokenResponseClient(new MockTokenExchangeClient());
return authorizedClientProvider;
}

}

abstract static class OAuth2ClientBaseConfig {
Expand Down Expand Up @@ -463,6 +519,14 @@ ClientRegistrationRepository clientRegistrationRepository() {
.clientId("okta-client-id")
.clientSecret("okta-client-secret")
.authorizationGrantType(AuthorizationGrantType.JWT_BEARER)
.build(),
ClientRegistration.withRegistrationId("auth0")
.clientName("Auth0")
.clientId("auth0-client-id")
.clientSecret("auth0-client-secret")
.clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_BASIC)
.authorizationGrantType(AuthorizationGrantType.TOKEN_EXCHANGE)
.scope("user.read", "user.write")
.build()));
// @formatter:on
}
Expand Down Expand Up @@ -544,4 +608,13 @@ public OAuth2AccessTokenResponse getTokenResponse(JwtBearerGrantRequest authoriz

}

private static class MockTokenExchangeClient implements OAuth2AccessTokenResponseClient<TokenExchangeGrantRequest> {

@Override
public OAuth2AccessTokenResponse getTokenResponse(TokenExchangeGrantRequest authorizationGrantRequest) {
return MOCK_RESPONSE_CLIENT.getTokenResponse(authorizationGrantRequest);
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -49,18 +49,21 @@
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientManager;
import org.springframework.security.oauth2.client.PasswordOAuth2AuthorizedClientProvider;
import org.springframework.security.oauth2.client.RefreshTokenOAuth2AuthorizedClientProvider;
import org.springframework.security.oauth2.client.TokenExchangeOAuth2AuthorizedClientProvider;
import org.springframework.security.oauth2.client.endpoint.AbstractOAuth2AuthorizationGrantRequest;
import org.springframework.security.oauth2.client.endpoint.JwtBearerGrantRequest;
import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient;
import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest;
import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest;
import org.springframework.security.oauth2.client.endpoint.OAuth2PasswordGrantRequest;
import org.springframework.security.oauth2.client.endpoint.OAuth2RefreshTokenGrantRequest;
import org.springframework.security.oauth2.client.endpoint.TokenExchangeGrantRequest;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
import org.springframework.security.oauth2.client.web.DefaultOAuth2AuthorizedClientManager;
import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
import org.springframework.security.oauth2.core.OAuth2Error;
Expand Down Expand Up @@ -316,6 +319,47 @@ private void testJwtBearerGrant() {
assertThat(grantRequest.getJwt().getSubject()).isEqualTo("user");
}

@Test
public void authorizeWhenTokenExchangeAccessTokenResponseClientBeanThenUsed() {
this.spring.configLocations(xml("clients")).autowire();
testTokenExchangeGrant();
}

@Test
public void authorizeWhenTokenExchangeAuthorizedClientProviderBeanThenUsed() {
this.spring.configLocations(xml("providers")).autowire();
testTokenExchangeGrant();
}

private void testTokenExchangeGrant() {
OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build();
given(MOCK_RESPONSE_CLIENT.getTokenResponse(any(TokenExchangeGrantRequest.class)))
.willReturn(accessTokenResponse);

JwtAuthenticationToken authentication = new JwtAuthenticationToken(getJwt());
ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId("auth0");
// @formatter:off
OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest
.withClientRegistrationId(clientRegistration.getRegistrationId())
.principal(authentication)
.attribute(HttpServletRequest.class.getName(), this.request)
.attribute(HttpServletResponse.class.getName(), this.response)
.build();
// @formatter:on
OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest);
assertThat(authorizedClient).isNotNull();

ArgumentCaptor<TokenExchangeGrantRequest> grantRequestCaptor = ArgumentCaptor
.forClass(TokenExchangeGrantRequest.class);
verify(MOCK_RESPONSE_CLIENT).getTokenResponse(grantRequestCaptor.capture());

TokenExchangeGrantRequest grantRequest = grantRequestCaptor.getValue();
assertThat(grantRequest.getClientRegistration().getRegistrationId())
.isEqualTo(clientRegistration.getRegistrationId());
assertThat(grantRequest.getGrantType()).isEqualTo(AuthorizationGrantType.TOKEN_EXCHANGE);
assertThat(grantRequest.getSubjectToken()).isEqualTo(authentication.getToken());
}

private static OAuth2AccessToken getExpiredAccessToken() {
Instant expiresAt = Instant.now().minusSeconds(60);
Instant issuedAt = expiresAt.minus(Duration.ofDays(1));
Expand Down Expand Up @@ -356,6 +400,14 @@ public static List<ClientRegistration> getClientRegistrations() {
.clientId("okta-client-id")
.clientSecret("okta-client-secret")
.authorizationGrantType(AuthorizationGrantType.JWT_BEARER)
.build(),
ClientRegistration.withRegistrationId("auth0")
.clientName("Auth0")
.clientId("auth0-client-id")
.clientSecret("auth0-client-secret")
.clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_BASIC)
.authorizationGrantType(AuthorizationGrantType.TOKEN_EXCHANGE)
.scope("user.read", "user.write")
.build());
// @formatter:on
}
Expand Down Expand Up @@ -422,6 +474,16 @@ public static OAuth2AccessTokenResponseClient<JwtBearerGrantRequest> jwtBearerAc
return new MockJwtBearerClient();
}

public static TokenExchangeOAuth2AuthorizedClientProvider tokenExchangeAuthorizedClientProvider() {
TokenExchangeOAuth2AuthorizedClientProvider authorizedClientProvider = new TokenExchangeOAuth2AuthorizedClientProvider();
authorizedClientProvider.setAccessTokenResponseClient(tokenExchangeAccessTokenResponseClient());
return authorizedClientProvider;
}

public static OAuth2AccessTokenResponseClient<TokenExchangeGrantRequest> tokenExchangeAccessTokenResponseClient() {
return new MockTokenExchangeClient();
}

private static class MockAuthorizationCodeClient
implements OAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> {

Expand Down Expand Up @@ -472,4 +534,13 @@ public OAuth2AccessTokenResponse getTokenResponse(JwtBearerGrantRequest authoriz

}

private static class MockTokenExchangeClient implements OAuth2AccessTokenResponseClient<TokenExchangeGrantRequest> {

@Override
public OAuth2AccessTokenResponse getTokenResponse(TokenExchangeGrantRequest authorizationGrantRequest) {
return MOCK_RESPONSE_CLIENT.getTokenResponse(authorizationGrantRequest);
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -53,4 +53,7 @@
<b:bean class="org.springframework.security.config.http.OAuth2AuthorizedClientManagerRegistrarTests"
factory-method="jwtBearerAccessTokenResponseClient"/>

<b:bean class="org.springframework.security.config.http.OAuth2AuthorizedClientManagerRegistrarTests"
factory-method="tokenExchangeAccessTokenResponseClient"/>

</b:beans>
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,7 @@
<b:bean class="org.springframework.security.config.http.OAuth2AuthorizedClientManagerRegistrarTests"
factory-method="jwtBearerAuthorizedClientProvider"/>

<b:bean class="org.springframework.security.config.http.OAuth2AuthorizedClientManagerRegistrarTests"
factory-method="tokenExchangeAuthorizedClientProvider"/>

</b:beans>

0 comments on commit d6382b8

Please sign in to comment.