Skip to content

Commit

Permalink
relay_state should not be included in signing calculation when it is …
Browse files Browse the repository at this point in the history
…null

Closes gh-13913
  • Loading branch information
marcusdacoregio committed Oct 19, 2023
1 parent 19c4e42 commit 70ad3bf
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -167,10 +167,12 @@ <T extends AbstractSaml2AuthenticationRequest> T resolve(HttpServletRequest requ
.samlRequest(deflatedAndEncoded)
.relayState(relayState);
if (registration.getAssertingPartyDetails().getWantAuthnRequestsSigned()) {
Map<String, String> parameters = OpenSamlSigningUtils.sign(registration)
.param(Saml2ParameterNames.SAML_REQUEST, deflatedAndEncoded)
.param(Saml2ParameterNames.RELAY_STATE, relayState)
.parameters();
OpenSamlSigningUtils.QueryParametersPartial parametersPartial = OpenSamlSigningUtils.sign(registration)
.param(Saml2ParameterNames.SAML_REQUEST, deflatedAndEncoded);
if (relayState != null) {
parametersPartial = parametersPartial.param(Saml2ParameterNames.RELAY_STATE, relayState);
}
Map<String, String> parameters = parametersPartial.parameters();
builder.sigAlg(parameters.get(Saml2ParameterNames.SIG_ALG))
.signature(parameters.get(Saml2ParameterNames.SIGNATURE));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,13 @@

import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.Answers;
import org.mockito.MockedStatic;
import org.opensaml.xmlsec.signature.support.SignatureConstants;

import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.security.saml2.Saml2Exception;
import org.springframework.security.saml2.core.Saml2ParameterNames;
import org.springframework.security.saml2.core.Saml2X509Credential;
import org.springframework.security.saml2.core.TestSaml2X509Credentials;
import org.springframework.security.saml2.provider.service.authentication.Saml2PostAuthenticationRequest;
Expand All @@ -32,6 +35,12 @@

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mockStatic;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;

/**
* Tests for {@link OpenSamlAuthenticationRequestResolver}
Expand Down Expand Up @@ -103,8 +112,8 @@ public void resolveAuthenticationRequestWhenSignedThenCredentialIsRequired() {
.build();
OpenSamlAuthenticationRequestResolver resolver = authenticationRequestResolver(registration);
assertThatExceptionOfType(Saml2Exception.class)
.isThrownBy(() -> resolver.resolve(request, (r, authnRequest) -> {
}));
.isThrownBy(() -> resolver.resolve(request, (r, authnRequest) -> {
}));
}

@Test
Expand Down Expand Up @@ -172,6 +181,58 @@ public void resolveAuthenticationRequestWhenSHA1SignRequestThenSigns() {
assertThat(result.getBinding()).isEqualTo(Saml2MessageBinding.REDIRECT);
}

@Test
public void resolveAuthenticationRequestWhenSignedAndRelayStateIsNullThenSignsWithoutRelayState() {
try (MockedStatic<OpenSamlSigningUtils> openSamlSigningUtilsMockedStatic = mockStatic(
OpenSamlSigningUtils.class, Answers.CALLS_REAL_METHODS)) {
MockHttpServletRequest request = new MockHttpServletRequest();
request.setPathInfo("/saml2/authenticate/registration-id");
RelyingPartyRegistration registration = this.relyingPartyRegistrationBuilder
.assertingPartyDetails((party) -> party.wantAuthnRequestsSigned(true))
.build();
OpenSamlSigningUtils.QueryParametersPartial queryParametersPartialSpy = spy(
new OpenSamlSigningUtils.QueryParametersPartial(registration));
openSamlSigningUtilsMockedStatic.when(() -> OpenSamlSigningUtils.sign(any()))
.thenReturn(queryParametersPartialSpy);
OpenSamlAuthenticationRequestResolver resolver = authenticationRequestResolver(registration);
resolver.setRelayStateResolver((source) -> null);
Saml2RedirectAuthenticationRequest result = resolver.resolve(request, (r, authnRequest) -> {
});
assertThat(result.getSamlRequest()).isNotEmpty();
assertThat(result.getRelayState()).isNull();
assertThat(result.getSigAlg()).isNotNull();
assertThat(result.getSignature()).isNotNull();
assertThat(result.getBinding()).isEqualTo(Saml2MessageBinding.REDIRECT);
verify(queryParametersPartialSpy, never()).param(eq(Saml2ParameterNames.RELAY_STATE), any());
}
}

@Test
public void resolveAuthenticationRequestWhenSignedAndRelayStateIsEmptyThenSignsWithEmptyRelayState() {
try (MockedStatic<OpenSamlSigningUtils> openSamlSigningUtilsMockedStatic = mockStatic(
OpenSamlSigningUtils.class, Answers.CALLS_REAL_METHODS)) {
MockHttpServletRequest request = new MockHttpServletRequest();
request.setPathInfo("/saml2/authenticate/registration-id");
RelyingPartyRegistration registration = this.relyingPartyRegistrationBuilder
.assertingPartyDetails((party) -> party.wantAuthnRequestsSigned(true))
.build();
OpenSamlSigningUtils.QueryParametersPartial queryParametersPartialSpy = spy(
new OpenSamlSigningUtils.QueryParametersPartial(registration));
openSamlSigningUtilsMockedStatic.when(() -> OpenSamlSigningUtils.sign(any()))
.thenReturn(queryParametersPartialSpy);
OpenSamlAuthenticationRequestResolver resolver = authenticationRequestResolver(registration);
resolver.setRelayStateResolver((source) -> "");
Saml2RedirectAuthenticationRequest result = resolver.resolve(request, (r, authnRequest) -> {
});
assertThat(result.getSamlRequest()).isNotEmpty();
assertThat(result.getRelayState()).isEmpty();
assertThat(result.getSigAlg()).isNotNull();
assertThat(result.getSignature()).isNotNull();
assertThat(result.getBinding()).isEqualTo(Saml2MessageBinding.REDIRECT);
verify(queryParametersPartialSpy).param(eq(Saml2ParameterNames.RELAY_STATE), eq(""));
}
}

private OpenSamlAuthenticationRequestResolver authenticationRequestResolver(RelyingPartyRegistration registration) {
return new OpenSamlAuthenticationRequestResolver((request, id) -> registration);
}
Expand Down

0 comments on commit 70ad3bf

Please sign in to comment.