Skip to content

Commit

Permalink
Merge pull request #360 from Emyrk/emyrk/Sec-WebSocket-Key
Browse files Browse the repository at this point in the history
Reject invalid "Sec-WebSocket-Key" headers from clients
  • Loading branch information
nhooyr committed Oct 19, 2023
2 parents 64ce009 + 305eab9 commit 10137fa
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 10 deletions.
13 changes: 12 additions & 1 deletion accept.go
Expand Up @@ -185,10 +185,21 @@ func verifyClientRequest(w http.ResponseWriter, r *http.Request) (errCode int, _
return http.StatusBadRequest, fmt.Errorf("unsupported WebSocket protocol version (only 13 is supported): %q", r.Header.Get("Sec-WebSocket-Version"))
}

if r.Header.Get("Sec-WebSocket-Key") == "" {
websocketSecKeys := r.Header.Values("Sec-WebSocket-Key")
if len(websocketSecKeys) == 0 {
return http.StatusBadRequest, errors.New("WebSocket protocol violation: missing Sec-WebSocket-Key")
}

if len(websocketSecKeys) > 1 {
return http.StatusBadRequest, errors.New("WebSocket protocol violation: multiple Sec-WebSocket-Key headers")
}

// The RFC states to remove any leading or trailing whitespace.
websocketSecKey := strings.TrimSpace(websocketSecKeys[0])
if v, err := base64.StdEncoding.DecodeString(websocketSecKey); err != nil || len(v) != 16 {
return http.StatusBadRequest, fmt.Errorf("WebSocket protocol violation: invalid Sec-WebSocket-Key %q, must be a 16 byte base64 encoded string", websocketSecKey)
}

return 0, nil
}

Expand Down
67 changes: 58 additions & 9 deletions accept_test.go
Expand Up @@ -13,6 +13,7 @@ import (
"testing"

"nhooyr.io/websocket/internal/test/assert"
"nhooyr.io/websocket/internal/test/xrand"
)

func TestAccept(t *testing.T) {
Expand All @@ -36,7 +37,7 @@ func TestAccept(t *testing.T) {
r.Header.Set("Connection", "Upgrade")
r.Header.Set("Upgrade", "websocket")
r.Header.Set("Sec-WebSocket-Version", "13")
r.Header.Set("Sec-WebSocket-Key", "meow123")
r.Header.Set("Sec-WebSocket-Key", xrand.Base64(16))
r.Header.Set("Origin", "harhar.com")

_, err := Accept(w, r, nil)
Expand All @@ -52,7 +53,7 @@ func TestAccept(t *testing.T) {
r.Header.Set("Connection", "Upgrade")
r.Header.Set("Upgrade", "websocket")
r.Header.Set("Sec-WebSocket-Version", "13")
r.Header.Set("Sec-WebSocket-Key", "meow123")
r.Header.Set("Sec-WebSocket-Key", xrand.Base64(16))
r.Header.Set("Origin", "https://harhar.com")

_, err := Accept(w, r, nil)
Expand All @@ -67,7 +68,7 @@ func TestAccept(t *testing.T) {
r.Header.Set("Connection", "Upgrade")
r.Header.Set("Upgrade", "websocket")
r.Header.Set("Sec-WebSocket-Version", "13")
r.Header.Set("Sec-WebSocket-Key", "meow123")
r.Header.Set("Sec-WebSocket-Key", xrand.Base64(16))
r.Header.Set("Sec-WebSocket-Extensions", extensions)
return r
}
Expand Down Expand Up @@ -116,7 +117,7 @@ func TestAccept(t *testing.T) {
r.Header.Set("Connection", "Upgrade")
r.Header.Set("Upgrade", "websocket")
r.Header.Set("Sec-WebSocket-Version", "13")
r.Header.Set("Sec-WebSocket-Key", "meow123")
r.Header.Set("Sec-WebSocket-Key", xrand.Base64(16))

_, err := Accept(w, r, nil)
assert.Contains(t, err, `http.ResponseWriter does not implement http.Hijacker`)
Expand All @@ -136,7 +137,7 @@ func TestAccept(t *testing.T) {
r.Header.Set("Connection", "Upgrade")
r.Header.Set("Upgrade", "websocket")
r.Header.Set("Sec-WebSocket-Version", "13")
r.Header.Set("Sec-WebSocket-Key", "meow123")
r.Header.Set("Sec-WebSocket-Key", xrand.Base64(16))

_, err := Accept(w, r, nil)
assert.Contains(t, err, `failed to hijack connection`)
Expand Down Expand Up @@ -183,21 +184,59 @@ func Test_verifyClientHandshake(t *testing.T) {
},
},
{
name: "badWebSocketKey",
name: "missingWebSocketKey",
h: map[string]string{
"Connection": "Upgrade",
"Upgrade": "websocket",
"Sec-WebSocket-Version": "13",
},
},
{
name: "emptyWebSocketKey",
h: map[string]string{
"Connection": "Upgrade",
"Upgrade": "websocket",
"Sec-WebSocket-Version": "13",
"Sec-WebSocket-Key": "",
},
},
{
name: "shortWebSocketKey",
h: map[string]string{
"Connection": "Upgrade",
"Upgrade": "websocket",
"Sec-WebSocket-Version": "13",
"Sec-WebSocket-Key": xrand.Base64(15),
},
},
{
name: "invalidWebSocketKey",
h: map[string]string{
"Connection": "Upgrade",
"Upgrade": "websocket",
"Sec-WebSocket-Version": "13",
"Sec-WebSocket-Key": "notbase64",
},
},
{
name: "extraWebSocketKey",
h: map[string]string{
"Connection": "Upgrade",
"Upgrade": "websocket",
"Sec-WebSocket-Version": "13",
// Kinda cheeky, but http headers are case-insensitive.
// If 2 sec keys are present, this is a failure condition.
"Sec-WebSocket-Key": xrand.Base64(16),
"sec-webSocket-key": xrand.Base64(16),
},
},
{
name: "badHTTPVersion",
h: map[string]string{
"Connection": "Upgrade",
"Upgrade": "websocket",
"Sec-WebSocket-Version": "13",
"Sec-WebSocket-Key": "meow123",
"Sec-WebSocket-Key": xrand.Base64(16),
},
http1: true,
},
Expand All @@ -207,7 +246,17 @@ func Test_verifyClientHandshake(t *testing.T) {
"Connection": "keep-alive, Upgrade",
"Upgrade": "websocket",
"Sec-WebSocket-Version": "13",
"Sec-WebSocket-Key": "meow123",
"Sec-WebSocket-Key": xrand.Base64(16),
},
success: true,
},
{
name: "successSecKeyExtraSpace",
h: map[string]string{
"Connection": "keep-alive, Upgrade",
"Upgrade": "websocket",
"Sec-WebSocket-Version": "13",
"Sec-WebSocket-Key": " " + xrand.Base64(16) + " ",
},
success: true,
},
Expand All @@ -227,7 +276,7 @@ func Test_verifyClientHandshake(t *testing.T) {
}

for k, v := range tc.h {
r.Header.Set(k, v)
r.Header.Add(k, v)
}

_, err := verifyClientRequest(httptest.NewRecorder(), r)
Expand Down
6 changes: 6 additions & 0 deletions internal/test/xrand/xrand.go
Expand Up @@ -2,6 +2,7 @@ package xrand

import (
"crypto/rand"
"encoding/base64"
"fmt"
"math/big"
"strings"
Expand Down Expand Up @@ -45,3 +46,8 @@ func Int(max int) int {
}
return int(x.Int64())
}

// Base64 returns a randomly generated base64 string of length n.
func Base64(n int) string {
return base64.StdEncoding.EncodeToString(Bytes(n))
}

0 comments on commit 10137fa

Please sign in to comment.