From 98ff542abe3108aa760c1558f80d393be0136539 Mon Sep 17 00:00:00 2001 From: Klaus Post Date: Wed, 29 Nov 2023 05:23:59 -0800 Subject: [PATCH] gzhttp: Allow overriding decompression on transport (#892) This allows getting compressed data even if `Content-Encoding` is set. Also allows decompression even if "Accept-Encoding" was not set by this client. --- gzhttp/transport.go | 25 ++++++++++++++++++---- gzhttp/transport_test.go | 45 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 66 insertions(+), 4 deletions(-) diff --git a/gzhttp/transport.go b/gzhttp/transport.go index a199fbc6e8..623aea2ed8 100644 --- a/gzhttp/transport.go +++ b/gzhttp/transport.go @@ -14,7 +14,7 @@ import ( "github.com/klauspost/compress/zstd" ) -// Transport will wrap a transport with a custom handler +// Transport will wrap an HTTP transport with a custom handler // that will request gzip and automatically decompress it. // Using this is significantly faster than using the default transport. func Transport(parent http.RoundTripper, opts ...transportOption) http.RoundTripper { @@ -51,10 +51,21 @@ func TransportEnableGzip(b bool) transportOption { } } +// TransportCustomEval will send the header of a response to a custom function. +// If the function returns false, the response will be returned as-is, +// Otherwise it will be decompressed based on Content-Encoding field, regardless +// of whether the transport added the encoding. +func TransportCustomEval(fn func(header http.Header) bool) transportOption { + return func(c *gzRoundtripper) { + c.customEval = fn + } +} + type gzRoundtripper struct { parent http.RoundTripper acceptEncoding string withZstd, withGzip bool + customEval func(header http.Header) bool } func (g *gzRoundtripper) RoundTrip(req *http.Request) (*http.Response, error) { @@ -82,16 +93,22 @@ func (g *gzRoundtripper) RoundTrip(req *http.Request) (*http.Response, error) { if err != nil || !requestedComp { return resp, err } - + decompress := false + if g.customEval != nil { + if !g.customEval(resp.Header) { + return resp, nil + } + decompress = true + } // Decompress - if g.withGzip && asciiEqualFold(resp.Header.Get("Content-Encoding"), "gzip") { + if (decompress || g.withGzip) && asciiEqualFold(resp.Header.Get("Content-Encoding"), "gzip") { resp.Body = &gzipReader{body: resp.Body} resp.Header.Del("Content-Encoding") resp.Header.Del("Content-Length") resp.ContentLength = -1 resp.Uncompressed = true } - if g.withZstd && asciiEqualFold(resp.Header.Get("Content-Encoding"), "zstd") { + if (decompress || g.withZstd) && asciiEqualFold(resp.Header.Get("Content-Encoding"), "zstd") { resp.Body = &zstdReader{body: resp.Body} resp.Header.Del("Content-Encoding") resp.Header.Del("Content-Length") diff --git a/gzhttp/transport_test.go b/gzhttp/transport_test.go index a059ac1b66..aff7edb4cf 100644 --- a/gzhttp/transport_test.go +++ b/gzhttp/transport_test.go @@ -206,6 +206,51 @@ func TestDefaultTransport(t *testing.T) { } } +func TestTransportCustomEval(t *testing.T) { + bin, err := os.ReadFile("testdata/benchmark.json") + if err != nil { + t.Fatal(err) + } + + server := httptest.NewServer(newTestHandler(bin)) + calledWith := "" + c := http.Client{Transport: Transport(http.DefaultTransport, TransportEnableZstd(false), TransportCustomEval(func(h http.Header) bool { + calledWith = h.Get("Content-Encoding") + return true + }))} + resp, err := c.Get(server.URL) + if err != nil { + t.Fatal(err) + } + got, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(got, bin) { + t.Errorf("data mismatch") + } + if calledWith != "gzip" { + t.Fatalf("Expected encoding %q, got %q", "gzip", calledWith) + } + // Test returning false + c = http.Client{Transport: Transport(http.DefaultTransport, TransportCustomEval(func(h http.Header) bool { + calledWith = h.Get("Content-Encoding") + return false + }))} + resp, err = c.Get(server.URL) + if err != nil { + t.Fatal(err) + } + // Check we got the compressed data + gotCE := resp.Header.Get("Content-Encoding") + if gotCE != "gzip" { + t.Fatalf("Expected encoding %q, got %q", "gzip", gotCE) + } + if calledWith != "gzip" { + t.Fatalf("Expected encoding %q, got %q", "gzip", calledWith) + } +} + func BenchmarkTransport(b *testing.B) { raw, err := os.ReadFile("testdata/benchmark.json") if err != nil {