Skip to content

Commit

Permalink
Do not cache Content-Type in ContentCachingResponseWrapper
Browse files Browse the repository at this point in the history
Based on feedback from several members of the community, we have
decided to revert the caching of the Content-Type header that was
introduced in ContentCachingResponseWrapper in 375e0e6.

This commit therefore completely removes Content-Type caching in
ContentCachingResponseWrapper and updates the existing tests
accordingly.

To provide guards against future regressions in this area, this commit
also introduces explicit tests for the 6 ways to set the content length
in ContentCachingResponseWrapper and modifies a test in
ShallowEtagHeaderFilterTests to ensure that a Content-Type header set
directly on ContentCachingResponseWrapper is propagated to the
underlying response even if content caching is disabled for the
ShallowEtagHeaderFilter.

See gh-32039
Closes gh-32317
  • Loading branch information
sbrannen committed Feb 26, 2024
1 parent 497aa3c commit cca440e
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 66 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,6 @@ public class ContentCachingResponseWrapper extends HttpServletResponseWrapper {
@Nullable
private Integer contentLength;

@Nullable
private String contentType;


/**
* Create a new ContentCachingResponseWrapper for the given servlet response.
Expand Down Expand Up @@ -150,28 +147,11 @@ public void setContentLengthLong(long len) {
setContentLength((int) len);
}

@Override
public void setContentType(@Nullable String type) {
this.contentType = type;
}

@Override
@Nullable
public String getContentType() {
if (this.contentType != null) {
return this.contentType;
}
return super.getContentType();
}

@Override
public boolean containsHeader(String name) {
if (this.contentLength != null && HttpHeaders.CONTENT_LENGTH.equalsIgnoreCase(name)) {
return true;
}
else if (this.contentType != null && HttpHeaders.CONTENT_TYPE.equalsIgnoreCase(name)) {
return true;
}
else {
return super.containsHeader(name);
}
Expand All @@ -182,9 +162,6 @@ public void setHeader(String name, String value) {
if (HttpHeaders.CONTENT_LENGTH.equalsIgnoreCase(name)) {
this.contentLength = Integer.valueOf(value);
}
else if (HttpHeaders.CONTENT_TYPE.equalsIgnoreCase(name)) {
this.contentType = value;
}
else {
super.setHeader(name, value);
}
Expand All @@ -195,9 +172,6 @@ public void addHeader(String name, String value) {
if (HttpHeaders.CONTENT_LENGTH.equalsIgnoreCase(name)) {
this.contentLength = Integer.valueOf(value);
}
else if (HttpHeaders.CONTENT_TYPE.equalsIgnoreCase(name)) {
this.contentType = value;
}
else {
super.addHeader(name, value);
}
Expand Down Expand Up @@ -229,9 +203,6 @@ public String getHeader(String name) {
if (this.contentLength != null && HttpHeaders.CONTENT_LENGTH.equalsIgnoreCase(name)) {
return this.contentLength.toString();
}
else if (this.contentType != null && HttpHeaders.CONTENT_TYPE.equalsIgnoreCase(name)) {
return this.contentType;
}
else {
return super.getHeader(name);
}
Expand All @@ -242,9 +213,6 @@ public Collection<String> getHeaders(String name) {
if (this.contentLength != null && HttpHeaders.CONTENT_LENGTH.equalsIgnoreCase(name)) {
return Collections.singleton(this.contentLength.toString());
}
else if (this.contentType != null && HttpHeaders.CONTENT_TYPE.equalsIgnoreCase(name)) {
return Collections.singleton(this.contentType);
}
else {
return super.getHeaders(name);
}
Expand All @@ -253,14 +221,9 @@ else if (this.contentType != null && HttpHeaders.CONTENT_TYPE.equalsIgnoreCase(n
@Override
public Collection<String> getHeaderNames() {
Collection<String> headerNames = super.getHeaderNames();
if (this.contentLength != null || this.contentType != null) {
if (this.contentLength != null) {
Set<String> result = new LinkedHashSet<>(headerNames);
if (this.contentLength != null) {
result.add(HttpHeaders.CONTENT_LENGTH);
}
if (this.contentType != null) {
result.add(HttpHeaders.CONTENT_TYPE);
}
result.add(HttpHeaders.CONTENT_LENGTH);
return result;
}
else {
Expand Down Expand Up @@ -333,10 +296,6 @@ protected void copyBodyToResponse(boolean complete) throws IOException {
}
this.contentLength = null;
}
if (this.contentType != null) {
rawResponse.setContentType(this.contentType);
this.contentType = null;
}
}
this.content.writeTo(rawResponse.getOutputStream());
this.content.reset();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,12 @@

package org.springframework.web.filter;

import java.util.function.BiConsumer;
import java.util.stream.Stream;

import jakarta.servlet.http.HttpServletResponse;
import org.junit.jupiter.api.Named;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;

import org.springframework.http.MediaType;
Expand All @@ -33,7 +32,6 @@
import static java.nio.charset.StandardCharsets.UTF_8;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Named.named;
import static org.junit.jupiter.params.provider.Arguments.arguments;
import static org.springframework.http.HttpHeaders.CONTENT_LENGTH;
import static org.springframework.http.HttpHeaders.CONTENT_TYPE;
import static org.springframework.http.HttpHeaders.TRANSFER_ENCODING;
Expand Down Expand Up @@ -120,39 +118,83 @@ void copyBodyToResponseWithPresetHeaders() throws Exception {
}

@ParameterizedTest(name = "[{index}] {0}")
@MethodSource("setContentTypeFunctions")
void copyBodyToResponseWithOverridingHeaders(BiConsumer<HttpServletResponse, String> setContentType) throws Exception {
@MethodSource("setContentLengthFunctions")
void copyBodyToResponseWithOverridingContentLength(SetContentLength setContentLength) throws Exception {
byte[] responseBody = "Hello World".getBytes(UTF_8);
int responseLength = responseBody.length;
int originalContentLength = 11;
int overridingContentLength = 22;
String originalContentType = MediaType.TEXT_PLAIN_VALUE;
String overridingContentType = MediaType.APPLICATION_JSON_VALUE;

MockHttpServletResponse response = new MockHttpServletResponse();
response.setContentLength(originalContentLength);
response.setContentType(originalContentType);

ContentCachingResponseWrapper responseWrapper = new ContentCachingResponseWrapper(response);
responseWrapper.setStatus(HttpServletResponse.SC_CREATED);
responseWrapper.setContentLength(overridingContentLength);
setContentType.accept(responseWrapper, overridingContentType);

assertThat(responseWrapper.getStatus()).isEqualTo(HttpServletResponse.SC_CREATED);
setContentLength.invoke(responseWrapper, overridingContentLength);

assertThat(responseWrapper.getContentSize()).isZero();
assertThat(responseWrapper.getHeaderNames()).containsExactlyInAnyOrder(CONTENT_TYPE, CONTENT_LENGTH);
assertThat(responseWrapper.getHeaderNames()).containsExactlyInAnyOrder(CONTENT_LENGTH);

assertHeader(response, CONTENT_LENGTH, originalContentLength);
assertHeader(responseWrapper, CONTENT_LENGTH, overridingContentLength);

FileCopyUtils.copy(responseBody, responseWrapper.getOutputStream());
assertThat(responseWrapper.getContentSize()).isEqualTo(responseLength);

responseWrapper.copyBodyToResponse();

assertThat(responseWrapper.getContentSize()).isZero();
assertThat(responseWrapper.getHeaderNames()).containsExactlyInAnyOrder(CONTENT_LENGTH);

assertHeader(response, CONTENT_LENGTH, responseLength);
assertHeader(responseWrapper, CONTENT_LENGTH, responseLength);

assertThat(response.getContentLength()).isEqualTo(responseLength);
assertThat(response.getContentAsByteArray()).isEqualTo(responseBody);
assertThat(response.getHeaderNames()).containsExactlyInAnyOrder(CONTENT_LENGTH);
}

private static Stream<Named<SetContentLength>> setContentLengthFunctions() {
return Stream.of(
named("setContentLength()", HttpServletResponse::setContentLength),
named("setContentLengthLong()", HttpServletResponse::setContentLengthLong),
named("setIntHeader()", (response, contentLength) -> response.setIntHeader(CONTENT_LENGTH, contentLength)),
named("addIntHeader()", (response, contentLength) -> response.addIntHeader(CONTENT_LENGTH, contentLength)),
named("setHeader()", (response, contentLength) -> response.setHeader(CONTENT_LENGTH, "" + contentLength)),
named("addHeader()", (response, contentLength) -> response.addHeader(CONTENT_LENGTH, "" + contentLength))
);
}

@ParameterizedTest(name = "[{index}] {0}")
@MethodSource("setContentTypeFunctions")
void copyBodyToResponseWithOverridingContentType(SetContentType setContentType) throws Exception {
byte[] responseBody = "Hello World".getBytes(UTF_8);
int responseLength = responseBody.length;
String originalContentType = MediaType.TEXT_PLAIN_VALUE;
String overridingContentType = MediaType.APPLICATION_JSON_VALUE;

MockHttpServletResponse response = new MockHttpServletResponse();
response.setContentType(originalContentType);

ContentCachingResponseWrapper responseWrapper = new ContentCachingResponseWrapper(response);

assertContentTypeHeader(response, originalContentType);
assertContentTypeHeader(responseWrapper, originalContentType);

setContentType.invoke(responseWrapper, overridingContentType);

assertThat(responseWrapper.getContentSize()).isZero();
assertThat(responseWrapper.getHeaderNames()).containsExactlyInAnyOrder(CONTENT_TYPE);

assertContentTypeHeader(response, overridingContentType);
assertContentTypeHeader(responseWrapper, overridingContentType);

FileCopyUtils.copy(responseBody, responseWrapper.getOutputStream());
assertThat(responseWrapper.getContentSize()).isEqualTo(responseLength);

responseWrapper.copyBodyToResponse();

assertThat(responseWrapper.getStatus()).isEqualTo(HttpServletResponse.SC_CREATED);
assertThat(responseWrapper.getContentSize()).isZero();
assertThat(responseWrapper.getHeaderNames()).containsExactlyInAnyOrder(CONTENT_TYPE, CONTENT_LENGTH);

Expand All @@ -161,24 +203,19 @@ void copyBodyToResponseWithOverridingHeaders(BiConsumer<HttpServletResponse, Str
assertContentTypeHeader(response, overridingContentType);
assertContentTypeHeader(responseWrapper, overridingContentType);

assertThat(response.getStatus()).isEqualTo(HttpServletResponse.SC_CREATED);
assertThat(response.getContentLength()).isEqualTo(responseLength);
assertThat(response.getContentAsByteArray()).isEqualTo(responseBody);
assertThat(response.getHeaderNames()).containsExactlyInAnyOrder(CONTENT_TYPE, CONTENT_LENGTH);
}

private static Stream<Arguments> setContentTypeFunctions() {
private static Stream<Named<SetContentType>> setContentTypeFunctions() {
return Stream.of(
namedArguments("setContentType()", HttpServletResponse::setContentType),
namedArguments("setHeader()", (response, contentType) -> response.setHeader(CONTENT_TYPE, contentType)),
namedArguments("addHeader()", (response, contentType) -> response.addHeader(CONTENT_TYPE, contentType))
named("setContentType()", HttpServletResponse::setContentType),
named("setHeader()", (response, contentType) -> response.setHeader(CONTENT_TYPE, contentType)),
named("addHeader()", (response, contentType) -> response.addHeader(CONTENT_TYPE, contentType))
);
}

private static Arguments namedArguments(String name, BiConsumer<HttpServletResponse, String> setContentTypeFunction) {
return arguments(named(name, setContentTypeFunction));
}

@Test
void copyBodyToResponseWithTransferEncoding() throws Exception {
byte[] responseBody = "6\r\nHello 5\r\nWorld0\r\n\r\n".getBytes(UTF_8);
Expand Down Expand Up @@ -218,4 +255,15 @@ private void assertContentTypeHeader(HttpServletResponse response, String conten
assertThat(response.getContentType()).as(CONTENT_TYPE).isEqualTo(contentType);
}


@FunctionalInterface
private interface SetContentLength {
void invoke(HttpServletResponse response, int contentLength);
}

@FunctionalInterface
private interface SetContentType {
void invoke(HttpServletResponse response, String contentType);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

import static java.nio.charset.StandardCharsets.UTF_8;
import static org.assertj.core.api.Assertions.assertThat;
import static org.springframework.http.MediaType.APPLICATION_JSON_VALUE;
import static org.springframework.http.MediaType.TEXT_PLAIN_VALUE;

/**
Expand All @@ -36,6 +37,7 @@
* @author Arjen Poutsma
* @author Brian Clozel
* @author Juergen Hoeller
* @author Sam Brannen
*/
class ShallowEtagHeaderFilterTests {

Expand Down Expand Up @@ -123,7 +125,7 @@ void filterMatch() throws Exception {
assertThat(response.getStatus()).as("Invalid status").isEqualTo(304);
assertThat(response.getHeader("ETag")).as("Invalid ETag").isEqualTo("\"0b10a8db164e0754105b7a99be72e3fe5\"");
assertThat(response.containsHeader("Content-Length")).as("Response has Content-Length header").isFalse();
assertThat(response.containsHeader("Content-Type")).as("Response has Content-Type header").isFalse();
assertThat(response.getContentType()).as("Invalid Content-Type header").isEqualTo(TEXT_PLAIN_VALUE);
assertThat(response.getContentAsByteArray()).as("Invalid content").isEmpty();
}

Expand Down Expand Up @@ -173,11 +175,13 @@ void filterWriter() throws Exception {
public void filterWriterWithDisabledCaching() throws Exception {
MockHttpServletRequest request = new MockHttpServletRequest("GET", "/hotels");
MockHttpServletResponse response = new MockHttpServletResponse();
response.setContentType(TEXT_PLAIN_VALUE);

byte[] responseBody = "Hello World".getBytes(UTF_8);
FilterChain filterChain = (filterRequest, filterResponse) -> {
assertThat(filterRequest).as("Invalid request passed").isEqualTo(request);
((HttpServletResponse) filterResponse).setStatus(HttpServletResponse.SC_OK);
filterResponse.setContentType(APPLICATION_JSON_VALUE);
FileCopyUtils.copy(responseBody, filterResponse.getOutputStream());
};

Expand All @@ -186,6 +190,7 @@ public void filterWriterWithDisabledCaching() throws Exception {

assertThat(response.getStatus()).isEqualTo(200);
assertThat(response.getHeader("ETag")).isNull();
assertThat(response.getContentType()).as("Invalid Content-Type header").isEqualTo(APPLICATION_JSON_VALUE);
assertThat(response.getContentAsByteArray()).isEqualTo(responseBody);
}

Expand Down

0 comments on commit cca440e

Please sign in to comment.