Skip to content

Commit

Permalink
Server selects first acceptable compression offer
Browse files Browse the repository at this point in the history
Unacceptable offers are declined without rejecting the request.
  • Loading branch information
abursavich committed Sep 9, 2020
1 parent cc2d7bd commit d522d62
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 118 deletions.
77 changes: 17 additions & 60 deletions accept.go
Expand Up @@ -118,9 +118,9 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con
w.Header().Set("Sec-WebSocket-Protocol", subproto)
}

copts, err := acceptCompression(r, w, opts.CompressionMode)
if err != nil {
return nil, err
copts, ok := selectDeflate(websocketExtensions(r.Header), opts.CompressionMode)
if ok {
w.Header().Set("Sec-WebSocket-Extensions", copts.String())
}

w.WriteHeader(http.StatusSwitchingProtocols)
Expand Down Expand Up @@ -230,26 +230,23 @@ func selectSubprotocol(r *http.Request, subprotocols []string) string {
return ""
}

func acceptCompression(r *http.Request, w http.ResponseWriter, mode CompressionMode) (*compressionOptions, error) {
func selectDeflate(extensions []websocketExtension, mode CompressionMode) (*compressionOptions, bool) {
if mode == CompressionDisabled {
return nil, nil
return nil, false
}

for _, ext := range websocketExtensions(r.Header) {
for _, ext := range extensions {
switch ext.name {
case "permessage-deflate":
return acceptDeflate(w, ext, mode)
// Disabled for now, see https://github.com/nhooyr/websocket/issues/218
// case "x-webkit-deflate-frame":
// return acceptWebkitDeflate(w, ext, mode)
if copts, ok := acceptDeflate(ext, mode); ok {
return copts, true
}
}
}
return nil, nil
return nil, false
}

func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode CompressionMode) (*compressionOptions, error) {
func acceptDeflate(ext websocketExtension, mode CompressionMode) (*compressionOptions, bool) {
copts := mode.opts()

for _, p := range ext.params {
switch p {
case "client_no_context_takeover":
Expand All @@ -258,57 +255,17 @@ func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode Compressi
case "server_no_context_takeover":
copts.serverNoContextTakeover = true
continue
case "server_max_window_bits=15":
continue
}

if strings.HasPrefix(p, "client_max_window_bits") {
// We cannot adjust the read sliding window so cannot make use of this.
case "client_max_window_bits",
"server_max_window_bits=15":
continue
}

err := fmt.Errorf("unsupported permessage-deflate parameter: %q", p)
http.Error(w, err.Error(), http.StatusBadRequest)
return nil, err
}

copts.setHeader(w.Header())

return copts, nil
}

func acceptWebkitDeflate(w http.ResponseWriter, ext websocketExtension, mode CompressionMode) (*compressionOptions, error) {
copts := mode.opts()
// The peer must explicitly request it.
copts.serverNoContextTakeover = false

for _, p := range ext.params {
if p == "no_context_takeover" {
copts.serverNoContextTakeover = true
if strings.HasPrefix(p, "client_max_window_bits=") {
// We can't adjust the deflate window, but decoding with a larger window is acceptable.
continue
}

// We explicitly fail on x-webkit-deflate-frame's max_window_bits parameter instead
// of ignoring it as the draft spec is unclear. It says the server can ignore it
// but the server has no way of signalling to the client it was ignored as the parameters
// are set one way.
// Thus us ignoring it would make the client think we understood it which would cause issues.
// See https://tools.ietf.org/html/draft-tyoshino-hybi-websocket-perframe-deflate-06#section-4.1
//
// Either way, we're only implementing this for webkit which never sends the max_window_bits
// parameter so we don't need to worry about it.
err := fmt.Errorf("unsupported x-webkit-deflate-frame parameter: %q", p)
http.Error(w, err.Error(), http.StatusBadRequest)
return nil, err
return nil, false
}

s := "x-webkit-deflate-frame"
if copts.clientNoContextTakeover {
s += "; no_context_takeover"
}
w.Header().Set("Sec-WebSocket-Extensions", s)

return copts, nil
return copts, true
}

func headerContainsToken(h http.Header, key, token string) bool {
Expand Down
122 changes: 68 additions & 54 deletions accept_test.go
Expand Up @@ -45,20 +45,47 @@ func TestAccept(t *testing.T) {
t.Run("badCompression", func(t *testing.T) {
t.Parallel()

w := mockHijacker{
ResponseWriter: httptest.NewRecorder(),
newRequest := func(extensions string) *http.Request {
r := httptest.NewRequest("GET", "/", nil)
r.Header.Set("Connection", "Upgrade")
r.Header.Set("Upgrade", "websocket")
r.Header.Set("Sec-WebSocket-Version", "13")
r.Header.Set("Sec-WebSocket-Key", "meow123")
r.Header.Set("Sec-WebSocket-Extensions", extensions)
return r
}
newResponseWriter := func() http.ResponseWriter {
return mockHijacker{
ResponseWriter: httptest.NewRecorder(),
hijack: func() (net.Conn, *bufio.ReadWriter, error) {
return nil, nil, errors.New("hijack error")
},
}
}
r := httptest.NewRequest("GET", "/", nil)
r.Header.Set("Connection", "Upgrade")
r.Header.Set("Upgrade", "websocket")
r.Header.Set("Sec-WebSocket-Version", "13")
r.Header.Set("Sec-WebSocket-Key", "meow123")
r.Header.Set("Sec-WebSocket-Extensions", "permessage-deflate; harharhar")

_, err := Accept(w, r, &AcceptOptions{
CompressionMode: CompressionContextTakeover,
t.Run("withoutFallback", func(t *testing.T) {
t.Parallel()

w := newResponseWriter()
r := newRequest("permessage-deflate; harharhar")
_, _ = Accept(w, r, &AcceptOptions{
CompressionMode: CompressionNoContextTakeover,
})
assert.Equal(t, "extension header", w.Header().Get("Sec-WebSocket-Extensions"), "")
})
t.Run("withFallback", func(t *testing.T) {
t.Parallel()

w := newResponseWriter()
r := newRequest("permessage-deflate; harharhar, permessage-deflate")
_, _ = Accept(w, r, &AcceptOptions{
CompressionMode: CompressionNoContextTakeover,
})
assert.Equal(t, "extension header",
w.Header().Get("Sec-WebSocket-Extensions"),
CompressionNoContextTakeover.opts().String(),
)
})
assert.Contains(t, err, `unsupported permessage-deflate parameter`)
})

t.Run("requireHttpHijacker", func(t *testing.T) {
Expand Down Expand Up @@ -321,79 +348,66 @@ func Test_authenticateOrigin(t *testing.T) {
}
}

func Test_acceptCompression(t *testing.T) {
func Test_selectDeflate(t *testing.T) {
t.Parallel()

testCases := []struct {
name string
mode CompressionMode
reqSecWebSocketExtensions string
respSecWebSocketExtensions string
expCopts *compressionOptions
error bool
name string
mode CompressionMode
header string
expCopts *compressionOptions
expOK bool
}{
{
name: "disabled",
mode: CompressionDisabled,
expCopts: nil,
expOK: false,
},
{
name: "noClientSupport",
mode: CompressionNoContextTakeover,
expCopts: nil,
expOK: false,
},
{
name: "permessage-deflate",
mode: CompressionNoContextTakeover,
reqSecWebSocketExtensions: "permessage-deflate; client_max_window_bits",
respSecWebSocketExtensions: "permessage-deflate; client_no_context_takeover; server_no_context_takeover",
name: "permessage-deflate",
mode: CompressionNoContextTakeover,
header: "permessage-deflate; client_max_window_bits",
expCopts: &compressionOptions{
clientNoContextTakeover: true,
serverNoContextTakeover: true,
},
expOK: true,
},
{
name: "permessage-deflate/unknown-parameter",
mode: CompressionNoContextTakeover,
header: "permessage-deflate; meow",
expOK: false,
},
{
name: "permessage-deflate/error",
mode: CompressionNoContextTakeover,
reqSecWebSocketExtensions: "permessage-deflate; meow",
error: true,
name: "permessage-deflate/unknown-parameter",
mode: CompressionNoContextTakeover,
header: "permessage-deflate; meow, permessage-deflate; client_max_window_bits",
expCopts: &compressionOptions{
clientNoContextTakeover: true,
serverNoContextTakeover: true,
},
expOK: true,
},
// {
// name: "x-webkit-deflate-frame",
// mode: CompressionNoContextTakeover,
// reqSecWebSocketExtensions: "x-webkit-deflate-frame; no_context_takeover",
// respSecWebSocketExtensions: "x-webkit-deflate-frame; no_context_takeover",
// expCopts: &compressionOptions{
// clientNoContextTakeover: true,
// serverNoContextTakeover: true,
// },
// },
// {
// name: "x-webkit-deflate/error",
// mode: CompressionNoContextTakeover,
// reqSecWebSocketExtensions: "x-webkit-deflate-frame; max_window_bits",
// error: true,
// },
}

for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

r := httptest.NewRequest(http.MethodGet, "/", nil)
r.Header.Set("Sec-WebSocket-Extensions", tc.reqSecWebSocketExtensions)

w := httptest.NewRecorder()
copts, err := acceptCompression(r, w, tc.mode)
if tc.error {
assert.Error(t, err)
return
}

assert.Success(t, err)
h := http.Header{}
h.Set("Sec-WebSocket-Extensions", tc.header)
copts, ok := selectDeflate(websocketExtensions(h), tc.mode)
assert.Equal(t, "selected options", tc.expOK, ok)
assert.Equal(t, "compression options", tc.expCopts, copts)
assert.Equal(t, "Sec-WebSocket-Extensions", tc.respSecWebSocketExtensions, w.Header().Get("Sec-WebSocket-Extensions"))
})
}
}
Expand Down
5 changes: 2 additions & 3 deletions compress.go
Expand Up @@ -5,7 +5,6 @@ package websocket
import (
"compress/flate"
"io"
"net/http"
"sync"
)

Expand Down Expand Up @@ -58,15 +57,15 @@ type compressionOptions struct {
serverNoContextTakeover bool
}

func (copts *compressionOptions) setHeader(h http.Header) {
func (copts *compressionOptions) String() string {
s := "permessage-deflate"
if copts.clientNoContextTakeover {
s += "; client_no_context_takeover"
}
if copts.serverNoContextTakeover {
s += "; server_no_context_takeover"
}
h.Set("Sec-WebSocket-Extensions", s)
return s
}

// These bytes are required to get flate.Reader to return.
Expand Down
2 changes: 1 addition & 1 deletion dial.go
Expand Up @@ -162,7 +162,7 @@ func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, copts
req.Header.Set("Sec-WebSocket-Protocol", strings.Join(opts.Subprotocols, ","))
}
if copts != nil {
copts.setHeader(req.Header)
req.Header.Set("Sec-WebSocket-Extensions", copts.String())
}

resp, err := opts.HTTPClient.Do(req)
Expand Down

0 comments on commit d522d62

Please sign in to comment.