diff --git a/accept.go b/accept.go index b90e15eb..285b3103 100644 --- a/accept.go +++ b/accept.go @@ -185,10 +185,21 @@ func verifyClientRequest(w http.ResponseWriter, r *http.Request) (errCode int, _ return http.StatusBadRequest, fmt.Errorf("unsupported WebSocket protocol version (only 13 is supported): %q", r.Header.Get("Sec-WebSocket-Version")) } - if r.Header.Get("Sec-WebSocket-Key") == "" { + websocketSecKeys := r.Header.Values("Sec-WebSocket-Key") + if len(websocketSecKeys) == 0 { return http.StatusBadRequest, errors.New("WebSocket protocol violation: missing Sec-WebSocket-Key") } + if len(websocketSecKeys) > 1 { + return http.StatusBadRequest, errors.New("WebSocket protocol violation: multiple Sec-WebSocket-Key headers") + } + + // The RFC states to remove any leading or trailing whitespace. + websocketSecKey := strings.TrimSpace(websocketSecKeys[0]) + if v, err := base64.StdEncoding.DecodeString(websocketSecKey); err != nil || len(v) != 16 { + return http.StatusBadRequest, fmt.Errorf("WebSocket protocol violation: invalid Sec-WebSocket-Key %q, must be a 16 byte base64 encoded string", websocketSecKey) + } + return 0, nil } diff --git a/accept_test.go b/accept_test.go index c554bdaf..7cb85d0f 100644 --- a/accept_test.go +++ b/accept_test.go @@ -13,6 +13,7 @@ import ( "testing" "nhooyr.io/websocket/internal/test/assert" + "nhooyr.io/websocket/internal/test/xrand" ) func TestAccept(t *testing.T) { @@ -36,7 +37,7 @@ func TestAccept(t *testing.T) { r.Header.Set("Connection", "Upgrade") r.Header.Set("Upgrade", "websocket") r.Header.Set("Sec-WebSocket-Version", "13") - r.Header.Set("Sec-WebSocket-Key", "meow123") + r.Header.Set("Sec-WebSocket-Key", xrand.Base64(16)) r.Header.Set("Origin", "harhar.com") _, err := Accept(w, r, nil) @@ -52,7 +53,7 @@ func TestAccept(t *testing.T) { r.Header.Set("Connection", "Upgrade") r.Header.Set("Upgrade", "websocket") r.Header.Set("Sec-WebSocket-Version", "13") - r.Header.Set("Sec-WebSocket-Key", "meow123") + r.Header.Set("Sec-WebSocket-Key", xrand.Base64(16)) r.Header.Set("Origin", "https://harhar.com") _, err := Accept(w, r, nil) @@ -67,7 +68,7 @@ func TestAccept(t *testing.T) { r.Header.Set("Connection", "Upgrade") r.Header.Set("Upgrade", "websocket") r.Header.Set("Sec-WebSocket-Version", "13") - r.Header.Set("Sec-WebSocket-Key", "meow123") + r.Header.Set("Sec-WebSocket-Key", xrand.Base64(16)) r.Header.Set("Sec-WebSocket-Extensions", extensions) return r } @@ -116,7 +117,7 @@ func TestAccept(t *testing.T) { r.Header.Set("Connection", "Upgrade") r.Header.Set("Upgrade", "websocket") r.Header.Set("Sec-WebSocket-Version", "13") - r.Header.Set("Sec-WebSocket-Key", "meow123") + r.Header.Set("Sec-WebSocket-Key", xrand.Base64(16)) _, err := Accept(w, r, nil) assert.Contains(t, err, `http.ResponseWriter does not implement http.Hijacker`) @@ -136,7 +137,7 @@ func TestAccept(t *testing.T) { r.Header.Set("Connection", "Upgrade") r.Header.Set("Upgrade", "websocket") r.Header.Set("Sec-WebSocket-Version", "13") - r.Header.Set("Sec-WebSocket-Key", "meow123") + r.Header.Set("Sec-WebSocket-Key", xrand.Base64(16)) _, err := Accept(w, r, nil) assert.Contains(t, err, `failed to hijack connection`) @@ -183,7 +184,15 @@ func Test_verifyClientHandshake(t *testing.T) { }, }, { - name: "badWebSocketKey", + name: "missingWebSocketKey", + h: map[string]string{ + "Connection": "Upgrade", + "Upgrade": "websocket", + "Sec-WebSocket-Version": "13", + }, + }, + { + name: "emptyWebSocketKey", h: map[string]string{ "Connection": "Upgrade", "Upgrade": "websocket", @@ -191,13 +200,43 @@ func Test_verifyClientHandshake(t *testing.T) { "Sec-WebSocket-Key": "", }, }, + { + name: "shortWebSocketKey", + h: map[string]string{ + "Connection": "Upgrade", + "Upgrade": "websocket", + "Sec-WebSocket-Version": "13", + "Sec-WebSocket-Key": xrand.Base64(15), + }, + }, + { + name: "invalidWebSocketKey", + h: map[string]string{ + "Connection": "Upgrade", + "Upgrade": "websocket", + "Sec-WebSocket-Version": "13", + "Sec-WebSocket-Key": "notbase64", + }, + }, + { + name: "extraWebSocketKey", + h: map[string]string{ + "Connection": "Upgrade", + "Upgrade": "websocket", + "Sec-WebSocket-Version": "13", + // Kinda cheeky, but http headers are case-insensitive. + // If 2 sec keys are present, this is a failure condition. + "Sec-WebSocket-Key": xrand.Base64(16), + "sec-webSocket-key": xrand.Base64(16), + }, + }, { name: "badHTTPVersion", h: map[string]string{ "Connection": "Upgrade", "Upgrade": "websocket", "Sec-WebSocket-Version": "13", - "Sec-WebSocket-Key": "meow123", + "Sec-WebSocket-Key": xrand.Base64(16), }, http1: true, }, @@ -207,7 +246,17 @@ func Test_verifyClientHandshake(t *testing.T) { "Connection": "keep-alive, Upgrade", "Upgrade": "websocket", "Sec-WebSocket-Version": "13", - "Sec-WebSocket-Key": "meow123", + "Sec-WebSocket-Key": xrand.Base64(16), + }, + success: true, + }, + { + name: "successSecKeyExtraSpace", + h: map[string]string{ + "Connection": "keep-alive, Upgrade", + "Upgrade": "websocket", + "Sec-WebSocket-Version": "13", + "Sec-WebSocket-Key": " " + xrand.Base64(16) + " ", }, success: true, }, @@ -227,7 +276,7 @@ func Test_verifyClientHandshake(t *testing.T) { } for k, v := range tc.h { - r.Header.Set(k, v) + r.Header.Add(k, v) } _, err := verifyClientRequest(httptest.NewRecorder(), r) diff --git a/internal/test/xrand/xrand.go b/internal/test/xrand/xrand.go index 8de1ede8..9bfb39ce 100644 --- a/internal/test/xrand/xrand.go +++ b/internal/test/xrand/xrand.go @@ -2,6 +2,7 @@ package xrand import ( "crypto/rand" + "encoding/base64" "fmt" "math/big" "strings" @@ -45,3 +46,8 @@ func Int(max int) int { } return int(x.Int64()) } + +// Base64 returns a randomly generated base64 string of length n. +func Base64(n int) string { + return base64.StdEncoding.EncodeToString(Bytes(n)) +}