Skip to content

Commit

Permalink
Merge pull request #258 from abursavich/compression
Browse files Browse the repository at this point in the history
Make compression negotiation more lenient
  • Loading branch information
nhooyr committed Oct 13, 2023
2 parents 81afa8a + d6b342b commit 97d7f90
Show file tree
Hide file tree
Showing 6 changed files with 109 additions and 91 deletions.
41 changes: 19 additions & 22 deletions accept.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,9 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con
w.Header().Set("Sec-WebSocket-Protocol", subproto)
}

copts, err := acceptCompression(r, w, opts.CompressionMode)
if err != nil {
return nil, err
copts, ok := selectDeflate(websocketExtensions(r.Header), opts.CompressionMode)
if ok {
w.Header().Set("Sec-WebSocket-Extensions", copts.String())
}

w.WriteHeader(http.StatusSwitchingProtocols)
Expand Down Expand Up @@ -238,25 +238,26 @@ func selectSubprotocol(r *http.Request, subprotocols []string) string {
return ""
}

func acceptCompression(r *http.Request, w http.ResponseWriter, mode CompressionMode) (*compressionOptions, error) {
func selectDeflate(extensions []websocketExtension, mode CompressionMode) (*compressionOptions, bool) {
if mode == CompressionDisabled {
return nil, nil
return nil, false
}

for _, ext := range websocketExtensions(r.Header) {
for _, ext := range extensions {
switch ext.name {
// We used to implement x-webkit-deflate-fram too but Safari has bugs.
// See https://github.com/nhooyr/websocket/issues/218
case "permessage-deflate":
return acceptDeflate(w, ext, mode)
copts, ok := acceptDeflate(ext, mode)
if ok {
return copts, true
}
}
}
return nil, nil
return nil, false
}

func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode CompressionMode) (*compressionOptions, error) {
func acceptDeflate(ext websocketExtension, mode CompressionMode) (*compressionOptions, bool) {
copts := mode.opts()

for _, p := range ext.params {
switch p {
case "client_no_context_takeover":
Expand All @@ -265,22 +266,18 @@ func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode Compressi
case "server_no_context_takeover":
copts.serverNoContextTakeover = true
continue
case "client_max_window_bits",
"server_max_window_bits=15":
continue
}

if strings.HasPrefix(p, "client_max_window_bits") {
// We cannot adjust the read sliding window so cannot make use of this.
// By not responding to it, we tell the client we're ignoring it.
if strings.HasPrefix(p, "client_max_window_bits=") {
// We can't adjust the deflate window, but decoding with a larger window is acceptable.
continue
}

err := fmt.Errorf("unsupported permessage-deflate parameter: %q", p)
http.Error(w, err.Error(), http.StatusBadRequest)
return nil, err
return nil, false
}

copts.setHeader(w.Header())

return copts, nil
return copts, true
}

func headerContainsTokenIgnoreCase(h http.Header, key, token string) bool {
Expand Down
125 changes: 71 additions & 54 deletions accept_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,20 +62,50 @@ func TestAccept(t *testing.T) {
t.Run("badCompression", func(t *testing.T) {
t.Parallel()

w := mockHijacker{
ResponseWriter: httptest.NewRecorder(),
newRequest := func(extensions string) *http.Request {
r := httptest.NewRequest("GET", "/", nil)
r.Header.Set("Connection", "Upgrade")
r.Header.Set("Upgrade", "websocket")
r.Header.Set("Sec-WebSocket-Version", "13")
r.Header.Set("Sec-WebSocket-Key", "meow123")
r.Header.Set("Sec-WebSocket-Extensions", extensions)
return r
}
errHijack := errors.New("hijack error")
newResponseWriter := func() http.ResponseWriter {
return mockHijacker{
ResponseWriter: httptest.NewRecorder(),
hijack: func() (net.Conn, *bufio.ReadWriter, error) {
return nil, nil, errHijack
},
}
}
r := httptest.NewRequest("GET", "/", nil)
r.Header.Set("Connection", "Upgrade")
r.Header.Set("Upgrade", "websocket")
r.Header.Set("Sec-WebSocket-Version", "13")
r.Header.Set("Sec-WebSocket-Key", "meow123")
r.Header.Set("Sec-WebSocket-Extensions", "permessage-deflate; harharhar")

_, err := Accept(w, r, &AcceptOptions{
CompressionMode: CompressionContextTakeover,
t.Run("withoutFallback", func(t *testing.T) {
t.Parallel()

w := newResponseWriter()
r := newRequest("permessage-deflate; harharhar")
_, err := Accept(w, r, &AcceptOptions{
CompressionMode: CompressionNoContextTakeover,
})
assert.ErrorIs(t, errHijack, err)
assert.Equal(t, "extension header", w.Header().Get("Sec-WebSocket-Extensions"), "")
})
t.Run("withFallback", func(t *testing.T) {
t.Parallel()

w := newResponseWriter()
r := newRequest("permessage-deflate; harharhar, permessage-deflate")
_, err := Accept(w, r, &AcceptOptions{
CompressionMode: CompressionNoContextTakeover,
})
assert.ErrorIs(t, errHijack, err)
assert.Equal(t, "extension header",
w.Header().Get("Sec-WebSocket-Extensions"),
CompressionNoContextTakeover.opts().String(),
)
})
assert.Contains(t, err, `unsupported permessage-deflate parameter`)
})

t.Run("requireHttpHijacker", func(t *testing.T) {
Expand Down Expand Up @@ -344,79 +374,66 @@ func Test_authenticateOrigin(t *testing.T) {
}
}

func Test_acceptCompression(t *testing.T) {
func Test_selectDeflate(t *testing.T) {
t.Parallel()

testCases := []struct {
name string
mode CompressionMode
reqSecWebSocketExtensions string
respSecWebSocketExtensions string
expCopts *compressionOptions
error bool
name string
mode CompressionMode
header string
expCopts *compressionOptions
expOK bool
}{
{
name: "disabled",
mode: CompressionDisabled,
expCopts: nil,
expOK: false,
},
{
name: "noClientSupport",
mode: CompressionNoContextTakeover,
expCopts: nil,
expOK: false,
},
{
name: "permessage-deflate",
mode: CompressionNoContextTakeover,
reqSecWebSocketExtensions: "permessage-deflate; client_max_window_bits",
respSecWebSocketExtensions: "permessage-deflate; client_no_context_takeover; server_no_context_takeover",
name: "permessage-deflate",
mode: CompressionNoContextTakeover,
header: "permessage-deflate; client_max_window_bits",
expCopts: &compressionOptions{
clientNoContextTakeover: true,
serverNoContextTakeover: true,
},
expOK: true,
},
{
name: "permessage-deflate/unknown-parameter",
mode: CompressionNoContextTakeover,
header: "permessage-deflate; meow",
expOK: false,
},
{
name: "permessage-deflate/error",
mode: CompressionNoContextTakeover,
reqSecWebSocketExtensions: "permessage-deflate; meow",
error: true,
name: "permessage-deflate/unknown-parameter",
mode: CompressionNoContextTakeover,
header: "permessage-deflate; meow, permessage-deflate; client_max_window_bits",
expCopts: &compressionOptions{
clientNoContextTakeover: true,
serverNoContextTakeover: true,
},
expOK: true,
},
// {
// name: "x-webkit-deflate-frame",
// mode: CompressionNoContextTakeover,
// reqSecWebSocketExtensions: "x-webkit-deflate-frame; no_context_takeover",
// respSecWebSocketExtensions: "x-webkit-deflate-frame; no_context_takeover",
// expCopts: &compressionOptions{
// clientNoContextTakeover: true,
// serverNoContextTakeover: true,
// },
// },
// {
// name: "x-webkit-deflate/error",
// mode: CompressionNoContextTakeover,
// reqSecWebSocketExtensions: "x-webkit-deflate-frame; max_window_bits",
// error: true,
// },
}

for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

r := httptest.NewRequest(http.MethodGet, "/", nil)
r.Header.Set("Sec-WebSocket-Extensions", tc.reqSecWebSocketExtensions)

w := httptest.NewRecorder()
copts, err := acceptCompression(r, w, tc.mode)
if tc.error {
assert.Error(t, err)
return
}

assert.Success(t, err)
h := http.Header{}
h.Set("Sec-WebSocket-Extensions", tc.header)
copts, ok := selectDeflate(websocketExtensions(h), tc.mode)
assert.Equal(t, "selected options", tc.expOK, ok)
assert.Equal(t, "compression options", tc.expCopts, copts)
assert.Equal(t, "Sec-WebSocket-Extensions", tc.respSecWebSocketExtensions, w.Header().Get("Sec-WebSocket-Extensions"))
})
}
}
Expand Down
11 changes: 3 additions & 8 deletions compress.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,13 @@ package websocket
import (
"compress/flate"
"io"
"net/http"
"sync"
)

// CompressionMode represents the modes available to the deflate extension.
// See https://tools.ietf.org/html/rfc7692
//
// A compatibility layer is implemented for the older deflate-frame extension used
// by safari. See https://tools.ietf.org/html/draft-tyoshino-hybi-websocket-perframe-deflate-06
// It will work the same in every way except that we cannot signal to the peer we
// want to use no context takeover on our side, we can only signal that they should.
// But it is currently disabled due to Safari bugs. See https://github.com/nhooyr/websocket/issues/218
// Works in all browsers except Safari which does not implement the deflate extension.
type CompressionMode int

const (
Expand Down Expand Up @@ -65,15 +60,15 @@ type compressionOptions struct {
serverNoContextTakeover bool
}

func (copts *compressionOptions) setHeader(h http.Header) {
func (copts *compressionOptions) String() string {
s := "permessage-deflate"
if copts.clientNoContextTakeover {
s += "; client_no_context_takeover"
}
if copts.serverNoContextTakeover {
s += "; server_no_context_takeover"
}
h.Set("Sec-WebSocket-Extensions", s)
return s
}

// These bytes are required to get flate.Reader to return.
Expand Down
6 changes: 5 additions & 1 deletion dial.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, copts
req.Header.Set("Sec-WebSocket-Protocol", strings.Join(opts.Subprotocols, ","))
}
if copts != nil {
copts.setHeader(req.Header)
req.Header.Set("Sec-WebSocket-Extensions", copts.String())
}

resp, err := opts.HTTPClient.Do(req)
Expand Down Expand Up @@ -273,6 +273,10 @@ func verifyServerExtensions(copts *compressionOptions, h http.Header) (*compress
copts.serverNoContextTakeover = true
continue
}
if strings.HasPrefix(p, "server_max_window_bits=") {
// We can't adjust the deflate window, but decoding with a larger window is acceptable.
continue
}

return nil, fmt.Errorf("unsupported permessage-deflate parameter: %q", p)
}
Expand Down
10 changes: 10 additions & 0 deletions internal/test/assert/assert.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package assert

import (
"errors"
"fmt"
"reflect"
"strings"
Expand Down Expand Up @@ -43,3 +44,12 @@ func Contains(t testing.TB, v interface{}, sub string) {
t.Fatalf("expected %q to contain %q", s, sub)
}
}

// ErrorIs asserts errors.Is(got, exp)
func ErrorIs(t testing.TB, exp, got error) {
t.Helper()

if !errors.Is(got, exp) {
t.Fatalf("expected %v but got %v", exp, got)
}
}
7 changes: 1 addition & 6 deletions ws_js.go
Original file line number Diff line number Diff line change
Expand Up @@ -485,12 +485,7 @@ func CloseStatus(err error) StatusCode {

// CompressionMode represents the modes available to the deflate extension.
// See https://tools.ietf.org/html/rfc7692
//
// A compatibility layer is implemented for the older deflate-frame extension used
// by safari. See https://tools.ietf.org/html/draft-tyoshino-hybi-websocket-perframe-deflate-06
// It will work the same in every way except that we cannot signal to the peer we
// want to use no context takeover on our side, we can only signal that they should.
// It is however currently disabled due to Safari bugs. See https://github.com/nhooyr/websocket/issues/218
// Works in all browsers except Safari which does not implement the deflate extension.
type CompressionMode int

const (
Expand Down

0 comments on commit 97d7f90

Please sign in to comment.