Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make compression negotiation more lenient #258

Merged
merged 4 commits into from
Oct 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
41 changes: 19 additions & 22 deletions accept.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,9 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con
w.Header().Set("Sec-WebSocket-Protocol", subproto)
}

copts, err := acceptCompression(r, w, opts.CompressionMode)
if err != nil {
return nil, err
copts, ok := selectDeflate(websocketExtensions(r.Header), opts.CompressionMode)
if ok {
w.Header().Set("Sec-WebSocket-Extensions", copts.String())
}

w.WriteHeader(http.StatusSwitchingProtocols)
Expand Down Expand Up @@ -238,25 +238,26 @@ func selectSubprotocol(r *http.Request, subprotocols []string) string {
return ""
}

func acceptCompression(r *http.Request, w http.ResponseWriter, mode CompressionMode) (*compressionOptions, error) {
func selectDeflate(extensions []websocketExtension, mode CompressionMode) (*compressionOptions, bool) {
if mode == CompressionDisabled {
return nil, nil
return nil, false
}

for _, ext := range websocketExtensions(r.Header) {
for _, ext := range extensions {
switch ext.name {
// We used to implement x-webkit-deflate-fram too but Safari has bugs.
// See https://github.com/nhooyr/websocket/issues/218
case "permessage-deflate":
return acceptDeflate(w, ext, mode)
copts, ok := acceptDeflate(ext, mode)
if ok {
return copts, true
}
}
}
return nil, nil
return nil, false
}

func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode CompressionMode) (*compressionOptions, error) {
func acceptDeflate(ext websocketExtension, mode CompressionMode) (*compressionOptions, bool) {
copts := mode.opts()

for _, p := range ext.params {
switch p {
case "client_no_context_takeover":
Expand All @@ -265,22 +266,18 @@ func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode Compressi
case "server_no_context_takeover":
copts.serverNoContextTakeover = true
continue
case "client_max_window_bits",
"server_max_window_bits=15":
continue
}

if strings.HasPrefix(p, "client_max_window_bits") {
// We cannot adjust the read sliding window so cannot make use of this.
// By not responding to it, we tell the client we're ignoring it.
if strings.HasPrefix(p, "client_max_window_bits=") {
// We can't adjust the deflate window, but decoding with a larger window is acceptable.
continue
}

err := fmt.Errorf("unsupported permessage-deflate parameter: %q", p)
http.Error(w, err.Error(), http.StatusBadRequest)
return nil, err
return nil, false
}

copts.setHeader(w.Header())

return copts, nil
return copts, true
}

func headerContainsTokenIgnoreCase(h http.Header, key, token string) bool {
Expand Down
125 changes: 71 additions & 54 deletions accept_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,20 +62,50 @@ func TestAccept(t *testing.T) {
t.Run("badCompression", func(t *testing.T) {
t.Parallel()

w := mockHijacker{
ResponseWriter: httptest.NewRecorder(),
newRequest := func(extensions string) *http.Request {
r := httptest.NewRequest("GET", "/", nil)
r.Header.Set("Connection", "Upgrade")
r.Header.Set("Upgrade", "websocket")
r.Header.Set("Sec-WebSocket-Version", "13")
r.Header.Set("Sec-WebSocket-Key", "meow123")
r.Header.Set("Sec-WebSocket-Extensions", extensions)
return r
}
errHijack := errors.New("hijack error")
newResponseWriter := func() http.ResponseWriter {
return mockHijacker{
ResponseWriter: httptest.NewRecorder(),
hijack: func() (net.Conn, *bufio.ReadWriter, error) {
return nil, nil, errHijack
},
}
}
r := httptest.NewRequest("GET", "/", nil)
r.Header.Set("Connection", "Upgrade")
r.Header.Set("Upgrade", "websocket")
r.Header.Set("Sec-WebSocket-Version", "13")
r.Header.Set("Sec-WebSocket-Key", "meow123")
r.Header.Set("Sec-WebSocket-Extensions", "permessage-deflate; harharhar")

_, err := Accept(w, r, &AcceptOptions{
CompressionMode: CompressionContextTakeover,
t.Run("withoutFallback", func(t *testing.T) {
t.Parallel()

w := newResponseWriter()
r := newRequest("permessage-deflate; harharhar")
_, err := Accept(w, r, &AcceptOptions{
CompressionMode: CompressionNoContextTakeover,
})
assert.ErrorIs(t, errHijack, err)
assert.Equal(t, "extension header", w.Header().Get("Sec-WebSocket-Extensions"), "")
})
t.Run("withFallback", func(t *testing.T) {
t.Parallel()

w := newResponseWriter()
r := newRequest("permessage-deflate; harharhar, permessage-deflate")
_, err := Accept(w, r, &AcceptOptions{
CompressionMode: CompressionNoContextTakeover,
})
assert.ErrorIs(t, errHijack, err)
assert.Equal(t, "extension header",
w.Header().Get("Sec-WebSocket-Extensions"),
CompressionNoContextTakeover.opts().String(),
)
})
assert.Contains(t, err, `unsupported permessage-deflate parameter`)
})

t.Run("requireHttpHijacker", func(t *testing.T) {
Expand Down Expand Up @@ -344,79 +374,66 @@ func Test_authenticateOrigin(t *testing.T) {
}
}

func Test_acceptCompression(t *testing.T) {
func Test_selectDeflate(t *testing.T) {
t.Parallel()

testCases := []struct {
name string
mode CompressionMode
reqSecWebSocketExtensions string
respSecWebSocketExtensions string
expCopts *compressionOptions
error bool
name string
mode CompressionMode
header string
expCopts *compressionOptions
expOK bool
}{
{
name: "disabled",
mode: CompressionDisabled,
expCopts: nil,
expOK: false,
},
{
name: "noClientSupport",
mode: CompressionNoContextTakeover,
expCopts: nil,
expOK: false,
},
{
name: "permessage-deflate",
mode: CompressionNoContextTakeover,
reqSecWebSocketExtensions: "permessage-deflate; client_max_window_bits",
respSecWebSocketExtensions: "permessage-deflate; client_no_context_takeover; server_no_context_takeover",
name: "permessage-deflate",
mode: CompressionNoContextTakeover,
header: "permessage-deflate; client_max_window_bits",
expCopts: &compressionOptions{
clientNoContextTakeover: true,
serverNoContextTakeover: true,
},
expOK: true,
},
{
name: "permessage-deflate/unknown-parameter",
mode: CompressionNoContextTakeover,
header: "permessage-deflate; meow",
expOK: false,
},
{
name: "permessage-deflate/error",
mode: CompressionNoContextTakeover,
reqSecWebSocketExtensions: "permessage-deflate; meow",
error: true,
name: "permessage-deflate/unknown-parameter",
mode: CompressionNoContextTakeover,
header: "permessage-deflate; meow, permessage-deflate; client_max_window_bits",
expCopts: &compressionOptions{
clientNoContextTakeover: true,
serverNoContextTakeover: true,
},
expOK: true,
},
// {
// name: "x-webkit-deflate-frame",
// mode: CompressionNoContextTakeover,
// reqSecWebSocketExtensions: "x-webkit-deflate-frame; no_context_takeover",
// respSecWebSocketExtensions: "x-webkit-deflate-frame; no_context_takeover",
// expCopts: &compressionOptions{
// clientNoContextTakeover: true,
// serverNoContextTakeover: true,
// },
// },
// {
// name: "x-webkit-deflate/error",
// mode: CompressionNoContextTakeover,
// reqSecWebSocketExtensions: "x-webkit-deflate-frame; max_window_bits",
// error: true,
// },
}

for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

r := httptest.NewRequest(http.MethodGet, "/", nil)
r.Header.Set("Sec-WebSocket-Extensions", tc.reqSecWebSocketExtensions)

w := httptest.NewRecorder()
copts, err := acceptCompression(r, w, tc.mode)
if tc.error {
assert.Error(t, err)
return
}

assert.Success(t, err)
h := http.Header{}
h.Set("Sec-WebSocket-Extensions", tc.header)
copts, ok := selectDeflate(websocketExtensions(h), tc.mode)
assert.Equal(t, "selected options", tc.expOK, ok)
assert.Equal(t, "compression options", tc.expCopts, copts)
assert.Equal(t, "Sec-WebSocket-Extensions", tc.respSecWebSocketExtensions, w.Header().Get("Sec-WebSocket-Extensions"))
})
}
}
Expand Down
11 changes: 3 additions & 8 deletions compress.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,13 @@ package websocket
import (
"compress/flate"
"io"
"net/http"
"sync"
)

// CompressionMode represents the modes available to the deflate extension.
// See https://tools.ietf.org/html/rfc7692
//
// A compatibility layer is implemented for the older deflate-frame extension used
// by safari. See https://tools.ietf.org/html/draft-tyoshino-hybi-websocket-perframe-deflate-06
// It will work the same in every way except that we cannot signal to the peer we
// want to use no context takeover on our side, we can only signal that they should.
// But it is currently disabled due to Safari bugs. See https://github.com/nhooyr/websocket/issues/218
// Works in all browsers except Safari which does not implement the deflate extension.
type CompressionMode int

const (
Expand Down Expand Up @@ -65,15 +60,15 @@ type compressionOptions struct {
serverNoContextTakeover bool
}

func (copts *compressionOptions) setHeader(h http.Header) {
func (copts *compressionOptions) String() string {
s := "permessage-deflate"
if copts.clientNoContextTakeover {
s += "; client_no_context_takeover"
}
if copts.serverNoContextTakeover {
s += "; server_no_context_takeover"
}
h.Set("Sec-WebSocket-Extensions", s)
return s
}

// These bytes are required to get flate.Reader to return.
Expand Down
6 changes: 5 additions & 1 deletion dial.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, copts
req.Header.Set("Sec-WebSocket-Protocol", strings.Join(opts.Subprotocols, ","))
}
if copts != nil {
copts.setHeader(req.Header)
req.Header.Set("Sec-WebSocket-Extensions", copts.String())
}

resp, err := opts.HTTPClient.Do(req)
Expand Down Expand Up @@ -273,6 +273,10 @@ func verifyServerExtensions(copts *compressionOptions, h http.Header) (*compress
copts.serverNoContextTakeover = true
continue
}
if strings.HasPrefix(p, "server_max_window_bits=") {
// We can't adjust the deflate window, but decoding with a larger window is acceptable.
continue
}

return nil, fmt.Errorf("unsupported permessage-deflate parameter: %q", p)
}
Expand Down
10 changes: 10 additions & 0 deletions internal/test/assert/assert.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package assert

import (
"errors"
"fmt"
"reflect"
"strings"
Expand Down Expand Up @@ -43,3 +44,12 @@ func Contains(t testing.TB, v interface{}, sub string) {
t.Fatalf("expected %q to contain %q", s, sub)
}
}

// ErrorIs asserts errors.Is(got, exp)
func ErrorIs(t testing.TB, exp, got error) {
t.Helper()

if !errors.Is(got, exp) {
t.Fatalf("expected %v but got %v", exp, got)
}
}
7 changes: 1 addition & 6 deletions ws_js.go
Original file line number Diff line number Diff line change
Expand Up @@ -485,12 +485,7 @@ func CloseStatus(err error) StatusCode {

// CompressionMode represents the modes available to the deflate extension.
// See https://tools.ietf.org/html/rfc7692
//
// A compatibility layer is implemented for the older deflate-frame extension used
// by safari. See https://tools.ietf.org/html/draft-tyoshino-hybi-websocket-perframe-deflate-06
// It will work the same in every way except that we cannot signal to the peer we
// want to use no context takeover on our side, we can only signal that they should.
// It is however currently disabled due to Safari bugs. See https://github.com/nhooyr/websocket/issues/218
// Works in all browsers except Safari which does not implement the deflate extension.
type CompressionMode int

const (
Expand Down