diff --git a/accept.go b/accept.go index e9691699..1b46428c 100644 --- a/accept.go +++ b/accept.go @@ -118,9 +118,9 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con w.Header().Set("Sec-WebSocket-Protocol", subproto) } - copts, err := acceptCompression(r, w, opts.CompressionMode) - if err != nil { - return nil, err + copts, ok := selectDeflate(websocketExtensions(r.Header), opts.CompressionMode) + if ok { + w.Header().Set("Sec-WebSocket-Extensions", copts.String()) } w.WriteHeader(http.StatusSwitchingProtocols) @@ -230,26 +230,23 @@ func selectSubprotocol(r *http.Request, subprotocols []string) string { return "" } -func acceptCompression(r *http.Request, w http.ResponseWriter, mode CompressionMode) (*compressionOptions, error) { +func selectDeflate(extensions []websocketExtension, mode CompressionMode) (*compressionOptions, bool) { if mode == CompressionDisabled { - return nil, nil + return nil, false } - - for _, ext := range websocketExtensions(r.Header) { + for _, ext := range extensions { switch ext.name { case "permessage-deflate": - return acceptDeflate(w, ext, mode) - // Disabled for now, see https://github.com/nhooyr/websocket/issues/218 - // case "x-webkit-deflate-frame": - // return acceptWebkitDeflate(w, ext, mode) + if copts, ok := acceptDeflate(ext, mode); ok { + return copts, true + } } } - return nil, nil + return nil, false } -func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode CompressionMode) (*compressionOptions, error) { +func acceptDeflate(ext websocketExtension, mode CompressionMode) (*compressionOptions, bool) { copts := mode.opts() - for _, p := range ext.params { switch p { case "client_no_context_takeover": @@ -258,57 +255,17 @@ func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode Compressi case "server_no_context_takeover": copts.serverNoContextTakeover = true continue - case "server_max_window_bits=15": - continue - } - - if strings.HasPrefix(p, "client_max_window_bits") { - // We cannot adjust the read sliding window so cannot make use of this. + case "client_max_window_bits", + "server_max_window_bits=15": continue } - - err := fmt.Errorf("unsupported permessage-deflate parameter: %q", p) - http.Error(w, err.Error(), http.StatusBadRequest) - return nil, err - } - - copts.setHeader(w.Header()) - - return copts, nil -} - -func acceptWebkitDeflate(w http.ResponseWriter, ext websocketExtension, mode CompressionMode) (*compressionOptions, error) { - copts := mode.opts() - // The peer must explicitly request it. - copts.serverNoContextTakeover = false - - for _, p := range ext.params { - if p == "no_context_takeover" { - copts.serverNoContextTakeover = true + if strings.HasPrefix(p, "client_max_window_bits=") { + // We can't adjust the deflate window, but decoding with a larger window is acceptable. continue } - - // We explicitly fail on x-webkit-deflate-frame's max_window_bits parameter instead - // of ignoring it as the draft spec is unclear. It says the server can ignore it - // but the server has no way of signalling to the client it was ignored as the parameters - // are set one way. - // Thus us ignoring it would make the client think we understood it which would cause issues. - // See https://tools.ietf.org/html/draft-tyoshino-hybi-websocket-perframe-deflate-06#section-4.1 - // - // Either way, we're only implementing this for webkit which never sends the max_window_bits - // parameter so we don't need to worry about it. - err := fmt.Errorf("unsupported x-webkit-deflate-frame parameter: %q", p) - http.Error(w, err.Error(), http.StatusBadRequest) - return nil, err + return nil, false } - - s := "x-webkit-deflate-frame" - if copts.clientNoContextTakeover { - s += "; no_context_takeover" - } - w.Header().Set("Sec-WebSocket-Extensions", s) - - return copts, nil + return copts, true } func headerContainsToken(h http.Header, key, token string) bool { diff --git a/accept_test.go b/accept_test.go index f7bc6693..f0233931 100644 --- a/accept_test.go +++ b/accept_test.go @@ -45,20 +45,47 @@ func TestAccept(t *testing.T) { t.Run("badCompression", func(t *testing.T) { t.Parallel() - w := mockHijacker{ - ResponseWriter: httptest.NewRecorder(), + newRequest := func(extensions string) *http.Request { + r := httptest.NewRequest("GET", "/", nil) + r.Header.Set("Connection", "Upgrade") + r.Header.Set("Upgrade", "websocket") + r.Header.Set("Sec-WebSocket-Version", "13") + r.Header.Set("Sec-WebSocket-Key", "meow123") + r.Header.Set("Sec-WebSocket-Extensions", extensions) + return r + } + newResponseWriter := func() http.ResponseWriter { + return mockHijacker{ + ResponseWriter: httptest.NewRecorder(), + hijack: func() (net.Conn, *bufio.ReadWriter, error) { + return nil, nil, errors.New("hijack error") + }, + } } - r := httptest.NewRequest("GET", "/", nil) - r.Header.Set("Connection", "Upgrade") - r.Header.Set("Upgrade", "websocket") - r.Header.Set("Sec-WebSocket-Version", "13") - r.Header.Set("Sec-WebSocket-Key", "meow123") - r.Header.Set("Sec-WebSocket-Extensions", "permessage-deflate; harharhar") - _, err := Accept(w, r, &AcceptOptions{ - CompressionMode: CompressionContextTakeover, + t.Run("withoutFallback", func(t *testing.T) { + t.Parallel() + + w := newResponseWriter() + r := newRequest("permessage-deflate; harharhar") + _, _ = Accept(w, r, &AcceptOptions{ + CompressionMode: CompressionNoContextTakeover, + }) + assert.Equal(t, "extension header", w.Header().Get("Sec-WebSocket-Extensions"), "") + }) + t.Run("withFallback", func(t *testing.T) { + t.Parallel() + + w := newResponseWriter() + r := newRequest("permessage-deflate; harharhar, permessage-deflate") + _, _ = Accept(w, r, &AcceptOptions{ + CompressionMode: CompressionNoContextTakeover, + }) + assert.Equal(t, "extension header", + w.Header().Get("Sec-WebSocket-Extensions"), + CompressionNoContextTakeover.opts().String(), + ) }) - assert.Contains(t, err, `unsupported permessage-deflate parameter`) }) t.Run("requireHttpHijacker", func(t *testing.T) { @@ -321,59 +348,54 @@ func Test_authenticateOrigin(t *testing.T) { } } -func Test_acceptCompression(t *testing.T) { +func Test_selectDeflate(t *testing.T) { t.Parallel() testCases := []struct { - name string - mode CompressionMode - reqSecWebSocketExtensions string - respSecWebSocketExtensions string - expCopts *compressionOptions - error bool + name string + mode CompressionMode + header string + expCopts *compressionOptions + expOK bool }{ { name: "disabled", mode: CompressionDisabled, expCopts: nil, + expOK: false, }, { name: "noClientSupport", mode: CompressionNoContextTakeover, expCopts: nil, + expOK: false, }, { - name: "permessage-deflate", - mode: CompressionNoContextTakeover, - reqSecWebSocketExtensions: "permessage-deflate; client_max_window_bits", - respSecWebSocketExtensions: "permessage-deflate; client_no_context_takeover; server_no_context_takeover", + name: "permessage-deflate", + mode: CompressionNoContextTakeover, + header: "permessage-deflate; client_max_window_bits", expCopts: &compressionOptions{ clientNoContextTakeover: true, serverNoContextTakeover: true, }, + expOK: true, + }, + { + name: "permessage-deflate/unknown-parameter", + mode: CompressionNoContextTakeover, + header: "permessage-deflate; meow", + expOK: false, }, { - name: "permessage-deflate/error", - mode: CompressionNoContextTakeover, - reqSecWebSocketExtensions: "permessage-deflate; meow", - error: true, + name: "permessage-deflate/unknown-parameter", + mode: CompressionNoContextTakeover, + header: "permessage-deflate; meow, permessage-deflate; client_max_window_bits", + expCopts: &compressionOptions{ + clientNoContextTakeover: true, + serverNoContextTakeover: true, + }, + expOK: true, }, - // { - // name: "x-webkit-deflate-frame", - // mode: CompressionNoContextTakeover, - // reqSecWebSocketExtensions: "x-webkit-deflate-frame; no_context_takeover", - // respSecWebSocketExtensions: "x-webkit-deflate-frame; no_context_takeover", - // expCopts: &compressionOptions{ - // clientNoContextTakeover: true, - // serverNoContextTakeover: true, - // }, - // }, - // { - // name: "x-webkit-deflate/error", - // mode: CompressionNoContextTakeover, - // reqSecWebSocketExtensions: "x-webkit-deflate-frame; max_window_bits", - // error: true, - // }, } for _, tc := range testCases { @@ -381,19 +403,11 @@ func Test_acceptCompression(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - r := httptest.NewRequest(http.MethodGet, "/", nil) - r.Header.Set("Sec-WebSocket-Extensions", tc.reqSecWebSocketExtensions) - - w := httptest.NewRecorder() - copts, err := acceptCompression(r, w, tc.mode) - if tc.error { - assert.Error(t, err) - return - } - - assert.Success(t, err) + h := http.Header{} + h.Set("Sec-WebSocket-Extensions", tc.header) + copts, ok := selectDeflate(websocketExtensions(h), tc.mode) + assert.Equal(t, "selected options", tc.expOK, ok) assert.Equal(t, "compression options", tc.expCopts, copts) - assert.Equal(t, "Sec-WebSocket-Extensions", tc.respSecWebSocketExtensions, w.Header().Get("Sec-WebSocket-Extensions")) }) } } diff --git a/compress.go b/compress.go index f49d9e5d..85fbc6e8 100644 --- a/compress.go +++ b/compress.go @@ -5,7 +5,6 @@ package websocket import ( "compress/flate" "io" - "net/http" "sync" ) @@ -58,7 +57,7 @@ type compressionOptions struct { serverNoContextTakeover bool } -func (copts *compressionOptions) setHeader(h http.Header) { +func (copts *compressionOptions) String() string { s := "permessage-deflate" if copts.clientNoContextTakeover { s += "; client_no_context_takeover" @@ -66,7 +65,7 @@ func (copts *compressionOptions) setHeader(h http.Header) { if copts.serverNoContextTakeover { s += "; server_no_context_takeover" } - h.Set("Sec-WebSocket-Extensions", s) + return s } // These bytes are required to get flate.Reader to return. diff --git a/dial.go b/dial.go index f31f690e..a5f852c6 100644 --- a/dial.go +++ b/dial.go @@ -162,7 +162,7 @@ func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, copts req.Header.Set("Sec-WebSocket-Protocol", strings.Join(opts.Subprotocols, ",")) } if copts != nil { - copts.setHeader(req.Header) + req.Header.Set("Sec-WebSocket-Extensions", copts.String()) } resp, err := opts.HTTPClient.Do(req)