Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add check for Sec-WebSocket-Key header #752

Merged
merged 2 commits into from Feb 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 2 additions & 2 deletions server.go
Expand Up @@ -154,8 +154,8 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
}

challengeKey := r.Header.Get("Sec-Websocket-Key")
if challengeKey == "" {
return u.returnError(w, r, http.StatusBadRequest, "websocket: not a websocket handshake: 'Sec-WebSocket-Key' header is missing or blank")
if !isValidChallengeKey(challengeKey) {
return u.returnError(w, r, http.StatusBadRequest, "websocket: not a websocket handshake: 'Sec-WebSocket-Key' header must be Base64 encoded value of 16-byte in length")
}

subprotocol := u.selectSubprotocol(r, responseHeader)
Expand Down
15 changes: 15 additions & 0 deletions util.go
Expand Up @@ -281,3 +281,18 @@ headers:
}
return result
}

// isValidChallengeKey checks if the argument meets RFC6455 specification.
func isValidChallengeKey(s string) bool {
// From RFC6455:
//
// A |Sec-WebSocket-Key| header field with a base64-encoded (see
// Section 4 of [RFC4648]) value that, when decoded, is 16 bytes in
// length.

if s == "" {
return false
}
decoded, err := base64.StdEncoding.DecodeString(s)
return err == nil && len(decoded) == 16
}
19 changes: 19 additions & 0 deletions util_test.go
Expand Up @@ -53,6 +53,25 @@ func TestTokenListContainsValue(t *testing.T) {
}
}

var isValidChallengeKeyTests = []struct {
key string
ok bool
}{
{"dGhlIHNhbXBsZSBub25jZQ==", true},
{"", false},
{"InvalidKey", false},
{"WHQ4eXhscUtKYjBvOGN3WEdtOEQ=", false},
}

func TestIsValidChallengeKey(t *testing.T) {
for _, tt := range isValidChallengeKeyTests {
ok := isValidChallengeKey(tt.key)
if ok != tt.ok {
t.Errorf("isValidChallengeKey returns %v, want %v", ok, tt.ok)
}
}
}

var parseExtensionTests = []struct {
value string
extensions []map[string]string
Expand Down