diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS deleted file mode 100644 index d2eae33e..00000000 --- a/.github/CODEOWNERS +++ /dev/null @@ -1 +0,0 @@ -* @nhooyr diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml deleted file mode 100644 index 3d9829ef..00000000 --- a/.github/workflows/ci.yaml +++ /dev/null @@ -1,39 +0,0 @@ -name: ci - -on: [push, pull_request] - -jobs: - fmt: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v1 - - name: Run ./ci/fmt.sh - uses: ./ci/container - with: - args: ./ci/fmt.sh - - lint: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v1 - - name: Run ./ci/lint.sh - uses: ./ci/container - with: - args: ./ci/lint.sh - - test: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v1 - - name: Run ./ci/test.sh - uses: ./ci/container - with: - args: ./ci/test.sh - env: - NETLIFY_AUTH_TOKEN: ${{ secrets.NETLIFY_AUTH_TOKEN }} - NETLIFY_SITE_ID: 9b3ee4dc-8297-4774-b4b9-a61561fbbce7 - - name: Upload coverage.html - uses: actions/upload-artifact@v2 - with: - name: coverage.html - path: ./ci/out/coverage.html diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 00000000..3c650580 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,47 @@ +name: ci +on: [push, pull_request] +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }} + cancel-in-progress: true + +jobs: + fmt: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-go@v4 + with: + go-version-file: ./go.mod + - run: ./ci/fmt.sh + + lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - run: go version + - uses: actions/setup-go@v4 + with: + go-version-file: ./go.mod + - run: ./ci/lint.sh + + test: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-go@v4 + with: + go-version-file: ./go.mod + - run: ./ci/test.sh + - uses: actions/upload-artifact@v3 + with: + name: coverage.html + path: ./ci/out/coverage.html + + bench: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-go@v4 + with: + go-version-file: ./go.mod + - run: ./ci/bench.sh diff --git a/.github/workflows/daily.yml b/.github/workflows/daily.yml new file mode 100644 index 00000000..b1e64fbc --- /dev/null +++ b/.github/workflows/daily.yml @@ -0,0 +1,54 @@ +name: daily +on: + workflow_dispatch: + schedule: + - cron: '42 0 * * *' # daily at 00:42 +concurrency: + group: ${{ github.workflow }} + cancel-in-progress: true + +jobs: + bench: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-go@v4 + with: + go-version-file: ./go.mod + - run: AUTOBAHN=1 ./ci/bench.sh + test: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-go@v4 + with: + go-version-file: ./go.mod + - run: AUTOBAHN=1 ./ci/test.sh + - uses: actions/upload-artifact@v3 + with: + name: coverage.html + path: ./ci/out/coverage.html + bench-dev: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + with: + ref: dev + - uses: actions/setup-go@v4 + with: + go-version-file: ./go.mod + - run: AUTOBAHN=1 ./ci/bench.sh + test-dev: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + with: + ref: dev + - uses: actions/setup-go@v4 + with: + go-version-file: ./go.mod + - run: AUTOBAHN=1 ./ci/test.sh + - uses: actions/upload-artifact@v3 + with: + name: coverage.html + path: ./ci/out/coverage.html diff --git a/.gitignore b/.gitignore deleted file mode 100644 index 6961e5c8..00000000 --- a/.gitignore +++ /dev/null @@ -1 +0,0 @@ -websocket.test diff --git a/LICENSE.txt b/LICENSE.txt index b5b5fef3..77b5bef6 100644 --- a/LICENSE.txt +++ b/LICENSE.txt @@ -1,21 +1,13 @@ -MIT License - -Copyright (c) 2018 Anmol Sethi - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. +Copyright (c) 2023 Anmol Sethi + +Permission to use, copy, modify, and distribute this software for any +purpose with or without fee is hereby granted, provided that the above +copyright notice and this permission notice appear in all copies. + +THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES +WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR +ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF +OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. diff --git a/README.md b/README.md index 15e3c8e6..1c5751d8 100644 --- a/README.md +++ b/README.md @@ -1,19 +1,13 @@ # websocket [![godoc](https://godoc.org/nhooyr.io/websocket?status.svg)](https://pkg.go.dev/nhooyr.io/websocket) -[![coverage](https://img.shields.io/badge/coverage-88%25-success)](https://nhooyrio-websocket-coverage.netlify.app) +[![coverage](https://img.shields.io/badge/coverage-91%25-success)](https://nhooyr.io/websocket/coverage.html) websocket is a minimal and idiomatic WebSocket library for Go. -> **note**: I haven't been responsive for questions/reports on the issue tracker but I do -> read through and there are no outstanding bugs. There are certainly some nice to haves -> that I should merge in/figure out but nothing critical. I haven't given up on adding new -> features and cleaning up the code further, just been busy. Should anything critical -> arise, I will fix it. - ## Install -```bash +```sh go get nhooyr.io/websocket ``` @@ -22,26 +16,37 @@ go get nhooyr.io/websocket - Minimal and idiomatic API - First class [context.Context](https://blog.golang.org/context) support - Fully passes the WebSocket [autobahn-testsuite](https://github.com/crossbario/autobahn-testsuite) -- [Single dependency](https://pkg.go.dev/nhooyr.io/websocket?tab=imports) -- JSON and protobuf helpers in the [wsjson](https://pkg.go.dev/nhooyr.io/websocket/wsjson) and [wspb](https://pkg.go.dev/nhooyr.io/websocket/wspb) subpackages +- [Zero dependencies](https://pkg.go.dev/nhooyr.io/websocket?tab=imports) +- JSON helpers in the [wsjson](https://pkg.go.dev/nhooyr.io/websocket/wsjson) subpackage - Zero alloc reads and writes - Concurrent writes - [Close handshake](https://pkg.go.dev/nhooyr.io/websocket#Conn.Close) - [net.Conn](https://pkg.go.dev/nhooyr.io/websocket#NetConn) wrapper - [Ping pong](https://pkg.go.dev/nhooyr.io/websocket#Conn.Ping) API - [RFC 7692](https://tools.ietf.org/html/rfc7692) permessage-deflate compression +- [CloseRead](https://pkg.go.dev/nhooyr.io/websocket#Conn.CloseRead) helper for write only connections - Compile to [Wasm](https://pkg.go.dev/nhooyr.io/websocket#hdr-Wasm) ## Roadmap +See GitHub issues for minor issues but the major future enhancements are: + +- [ ] Perfect examples [#217](https://github.com/nhooyr/websocket/issues/217) +- [ ] wstest.Pipe for in memory testing [#340](https://github.com/nhooyr/websocket/issues/340) +- [ ] Ping pong heartbeat helper [#267](https://github.com/nhooyr/websocket/issues/267) +- [ ] Ping pong instrumentation callbacks [#246](https://github.com/nhooyr/websocket/issues/246) +- [ ] Graceful shutdown helpers [#209](https://github.com/nhooyr/websocket/issues/209) +- [ ] Assembly for WebSocket masking [#16](https://github.com/nhooyr/websocket/issues/16) + - WIP at [#326](https://github.com/nhooyr/websocket/pull/326), about 3x faster - [ ] HTTP/2 [#4](https://github.com/nhooyr/websocket/issues/4) +- [ ] The holy grail [#402](https://github.com/nhooyr/websocket/issues/402) ## Examples For a production quality example that demonstrates the complete API, see the -[echo example](./examples/echo). +[echo example](./internal/examples/echo). -For a full stack example, see the [chat example](./examples/chat). +For a full stack example, see the [chat example](./internal/examples/chat). ### Server @@ -51,7 +56,7 @@ http.HandlerFunc(func (w http.ResponseWriter, r *http.Request) { if err != nil { // ... } - defer c.Close(websocket.StatusInternalError, "the sky is falling") + defer c.CloseNow() ctx, cancel := context.WithTimeout(r.Context(), time.Second*10) defer cancel() @@ -78,7 +83,7 @@ c, _, err := websocket.Dial(ctx, "ws://localhost:8080", nil) if err != nil { // ... } -defer c.Close(websocket.StatusInternalError, "the sky is falling") +defer c.CloseNow() err = wsjson.Write(ctx, c, "hi") if err != nil { @@ -113,14 +118,13 @@ Advantages of nhooyr.io/websocket: - Idiomatic [ping pong](https://pkg.go.dev/nhooyr.io/websocket#Conn.Ping) API - Gorilla requires registering a pong callback before sending a Ping - Can target Wasm ([gorilla/websocket#432](https://github.com/gorilla/websocket/issues/432)) -- Transparent message buffer reuse with [wsjson](https://pkg.go.dev/nhooyr.io/websocket/wsjson) and [wspb](https://pkg.go.dev/nhooyr.io/websocket/wspb) subpackages +- Transparent message buffer reuse with [wsjson](https://pkg.go.dev/nhooyr.io/websocket/wsjson) subpackage - [1.75x](https://github.com/nhooyr/websocket/releases/tag/v1.7.4) faster WebSocket masking implementation in pure Go - Gorilla's implementation is slower and uses [unsafe](https://golang.org/pkg/unsafe/). + Soon we'll have assembly and be 3x faster [#326](https://github.com/nhooyr/websocket/pull/326) - Full [permessage-deflate](https://tools.ietf.org/html/rfc7692) compression extension support - Gorilla only supports no context takeover mode - - We use [klauspost/compress](https://github.com/klauspost/compress) for much lower memory usage ([gorilla/websocket#203](https://github.com/gorilla/websocket/issues/203)) -- [CloseRead](https://pkg.go.dev/nhooyr.io/websocket#Conn.CloseRead) helper ([gorilla/websocket#492](https://github.com/gorilla/websocket/issues/492)) -- Actively maintained ([gorilla/websocket#370](https://github.com/gorilla/websocket/issues/370)) +- [CloseRead](https://pkg.go.dev/nhooyr.io/websocket#Conn.CloseRead) helper for write only connections ([gorilla/websocket#492](https://github.com/gorilla/websocket/issues/492)) #### golang.org/x/net/websocket @@ -135,4 +139,15 @@ to nhooyr.io/websocket. [gobwas/ws](https://github.com/gobwas/ws) has an extremely flexible API that allows it to be used in an event driven style for performance. See the author's [blog post](https://medium.freecodecamp.org/million-websockets-and-go-cc58418460bb). -However when writing idiomatic Go, nhooyr.io/websocket will be faster and easier to use. +However it is quite bloated. See https://pkg.go.dev/github.com/gobwas/ws + +When writing idiomatic Go, nhooyr.io/websocket will be faster and easier to use. + +#### lesismal/nbio + +[lesismal/nbio](https://github.com/lesismal/nbio) is similar to gobwas/ws in that the API is +event driven for performance reasons. + +However it is quite bloated. See https://pkg.go.dev/github.com/lesismal/nbio + +When writing idiomatic Go, nhooyr.io/websocket will be faster and easier to use. diff --git a/accept.go b/accept.go index 18536bdb..285b3103 100644 --- a/accept.go +++ b/accept.go @@ -1,3 +1,4 @@ +//go:build !js // +build !js package websocket @@ -51,7 +52,7 @@ type AcceptOptions struct { OriginPatterns []string // CompressionMode controls the compression mode. - // Defaults to CompressionNoContextTakeover. + // Defaults to CompressionDisabled. // // See docs on CompressionMode for details. CompressionMode CompressionMode @@ -63,6 +64,14 @@ type AcceptOptions struct { CompressionThreshold int } +func (opts *AcceptOptions) cloneWithDefaults() *AcceptOptions { + var o AcceptOptions + if opts != nil { + o = *opts + } + return &o +} + // Accept accepts a WebSocket handshake from a client and upgrades the // the connection to a WebSocket. // @@ -77,17 +86,13 @@ func Accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Conn, err error) { defer errd.Wrap(&err, "failed to accept WebSocket connection") - if opts == nil { - opts = &AcceptOptions{} - } - opts = &*opts - errCode, err := verifyClientRequest(w, r) if err != nil { http.Error(w, err.Error(), errCode) return nil, err } + opts = opts.cloneWithDefaults() if !opts.InsecureSkipVerify { err = authenticateOrigin(r, opts.OriginPatterns) if err != nil { @@ -118,9 +123,9 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con w.Header().Set("Sec-WebSocket-Protocol", subproto) } - copts, err := acceptCompression(r, w, opts.CompressionMode) - if err != nil { - return nil, err + copts, ok := selectDeflate(websocketExtensions(r.Header), opts.CompressionMode) + if ok { + w.Header().Set("Sec-WebSocket-Extensions", copts.String()) } w.WriteHeader(http.StatusSwitchingProtocols) @@ -180,10 +185,21 @@ func verifyClientRequest(w http.ResponseWriter, r *http.Request) (errCode int, _ return http.StatusBadRequest, fmt.Errorf("unsupported WebSocket protocol version (only 13 is supported): %q", r.Header.Get("Sec-WebSocket-Version")) } - if r.Header.Get("Sec-WebSocket-Key") == "" { + websocketSecKeys := r.Header.Values("Sec-WebSocket-Key") + if len(websocketSecKeys) == 0 { return http.StatusBadRequest, errors.New("WebSocket protocol violation: missing Sec-WebSocket-Key") } + if len(websocketSecKeys) > 1 { + return http.StatusBadRequest, errors.New("WebSocket protocol violation: multiple Sec-WebSocket-Key headers") + } + + // The RFC states to remove any leading or trailing whitespace. + websocketSecKey := strings.TrimSpace(websocketSecKeys[0]) + if v, err := base64.StdEncoding.DecodeString(websocketSecKey); err != nil || len(v) != 16 { + return http.StatusBadRequest, fmt.Errorf("WebSocket protocol violation: invalid Sec-WebSocket-Key %q, must be a 16 byte base64 encoded string", websocketSecKey) + } + return 0, nil } @@ -211,7 +227,10 @@ func authenticateOrigin(r *http.Request, originHosts []string) error { return nil } } - return fmt.Errorf("request Origin %q is not authorized for Host %q", origin, r.Host) + if u.Host == "" { + return fmt.Errorf("request Origin %q is not a valid URL with a host", origin) + } + return fmt.Errorf("request Origin %q is not authorized for Host %q", u.Host, r.Host) } func match(pattern, s string) (bool, error) { @@ -230,26 +249,26 @@ func selectSubprotocol(r *http.Request, subprotocols []string) string { return "" } -func acceptCompression(r *http.Request, w http.ResponseWriter, mode CompressionMode) (*compressionOptions, error) { +func selectDeflate(extensions []websocketExtension, mode CompressionMode) (*compressionOptions, bool) { if mode == CompressionDisabled { - return nil, nil + return nil, false } - - for _, ext := range websocketExtensions(r.Header) { + for _, ext := range extensions { switch ext.name { + // We used to implement x-webkit-deflate-frame too for Safari but Safari has bugs... + // See https://github.com/nhooyr/websocket/issues/218 case "permessage-deflate": - return acceptDeflate(w, ext, mode) - // Disabled for now, see https://github.com/nhooyr/websocket/issues/218 - // case "x-webkit-deflate-frame": - // return acceptWebkitDeflate(w, ext, mode) + copts, ok := acceptDeflate(ext, mode) + if ok { + return copts, true + } } } - return nil, nil + return nil, false } -func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode CompressionMode) (*compressionOptions, error) { +func acceptDeflate(ext websocketExtension, mode CompressionMode) (*compressionOptions, bool) { copts := mode.opts() - for _, p := range ext.params { switch p { case "client_no_context_takeover": @@ -258,55 +277,18 @@ func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode Compressi case "server_no_context_takeover": copts.serverNoContextTakeover = true continue - } - - if strings.HasPrefix(p, "client_max_window_bits") { - // We cannot adjust the read sliding window so cannot make use of this. + case "client_max_window_bits", + "server_max_window_bits=15": continue } - err := fmt.Errorf("unsupported permessage-deflate parameter: %q", p) - http.Error(w, err.Error(), http.StatusBadRequest) - return nil, err - } - - copts.setHeader(w.Header()) - - return copts, nil -} - -func acceptWebkitDeflate(w http.ResponseWriter, ext websocketExtension, mode CompressionMode) (*compressionOptions, error) { - copts := mode.opts() - // The peer must explicitly request it. - copts.serverNoContextTakeover = false - - for _, p := range ext.params { - if p == "no_context_takeover" { - copts.serverNoContextTakeover = true + if strings.HasPrefix(p, "client_max_window_bits=") { + // We can't adjust the deflate window, but decoding with a larger window is acceptable. continue } - - // We explicitly fail on x-webkit-deflate-frame's max_window_bits parameter instead - // of ignoring it as the draft spec is unclear. It says the server can ignore it - // but the server has no way of signalling to the client it was ignored as the parameters - // are set one way. - // Thus us ignoring it would make the client think we understood it which would cause issues. - // See https://tools.ietf.org/html/draft-tyoshino-hybi-websocket-perframe-deflate-06#section-4.1 - // - // Either way, we're only implementing this for webkit which never sends the max_window_bits - // parameter so we don't need to worry about it. - err := fmt.Errorf("unsupported x-webkit-deflate-frame parameter: %q", p) - http.Error(w, err.Error(), http.StatusBadRequest) - return nil, err - } - - s := "x-webkit-deflate-frame" - if copts.clientNoContextTakeover { - s += "; no_context_takeover" + return nil, false } - w.Header().Set("Sec-WebSocket-Extensions", s) - - return copts, nil + return copts, true } func headerContainsTokenIgnoreCase(h http.Header, key, token string) bool { diff --git a/accept_js.go b/accept_js.go deleted file mode 100644 index daad4b79..00000000 --- a/accept_js.go +++ /dev/null @@ -1,20 +0,0 @@ -package websocket - -import ( - "errors" - "net/http" -) - -// AcceptOptions represents Accept's options. -type AcceptOptions struct { - Subprotocols []string - InsecureSkipVerify bool - OriginPatterns []string - CompressionMode CompressionMode - CompressionThreshold int -} - -// Accept is stubbed out for Wasm. -func Accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, error) { - return nil, errors.New("unimplemented") -} diff --git a/accept_test.go b/accept_test.go index e114d1ad..7cb85d0f 100644 --- a/accept_test.go +++ b/accept_test.go @@ -1,3 +1,4 @@ +//go:build !js // +build !js package websocket @@ -12,6 +13,7 @@ import ( "testing" "nhooyr.io/websocket/internal/test/assert" + "nhooyr.io/websocket/internal/test/xrand" ) func TestAccept(t *testing.T) { @@ -35,28 +37,76 @@ func TestAccept(t *testing.T) { r.Header.Set("Connection", "Upgrade") r.Header.Set("Upgrade", "websocket") r.Header.Set("Sec-WebSocket-Version", "13") - r.Header.Set("Sec-WebSocket-Key", "meow123") + r.Header.Set("Sec-WebSocket-Key", xrand.Base64(16)) r.Header.Set("Origin", "harhar.com") _, err := Accept(w, r, nil) - assert.Contains(t, err, `request Origin "harhar.com" is not authorized for Host`) + assert.Contains(t, err, `request Origin "harhar.com" is not a valid URL with a host`) }) - t.Run("badCompression", func(t *testing.T) { + // #247 + t.Run("unauthorizedOriginErrorMessage", func(t *testing.T) { t.Parallel() - w := mockHijacker{ - ResponseWriter: httptest.NewRecorder(), - } + w := httptest.NewRecorder() r := httptest.NewRequest("GET", "/", nil) r.Header.Set("Connection", "Upgrade") r.Header.Set("Upgrade", "websocket") r.Header.Set("Sec-WebSocket-Version", "13") - r.Header.Set("Sec-WebSocket-Key", "meow123") - r.Header.Set("Sec-WebSocket-Extensions", "permessage-deflate; harharhar") + r.Header.Set("Sec-WebSocket-Key", xrand.Base64(16)) + r.Header.Set("Origin", "https://harhar.com") _, err := Accept(w, r, nil) - assert.Contains(t, err, `unsupported permessage-deflate parameter`) + assert.Contains(t, err, `request Origin "harhar.com" is not authorized for Host "example.com"`) + }) + + t.Run("badCompression", func(t *testing.T) { + t.Parallel() + + newRequest := func(extensions string) *http.Request { + r := httptest.NewRequest("GET", "/", nil) + r.Header.Set("Connection", "Upgrade") + r.Header.Set("Upgrade", "websocket") + r.Header.Set("Sec-WebSocket-Version", "13") + r.Header.Set("Sec-WebSocket-Key", xrand.Base64(16)) + r.Header.Set("Sec-WebSocket-Extensions", extensions) + return r + } + errHijack := errors.New("hijack error") + newResponseWriter := func() http.ResponseWriter { + return mockHijacker{ + ResponseWriter: httptest.NewRecorder(), + hijack: func() (net.Conn, *bufio.ReadWriter, error) { + return nil, nil, errHijack + }, + } + } + + t.Run("withoutFallback", func(t *testing.T) { + t.Parallel() + + w := newResponseWriter() + r := newRequest("permessage-deflate; harharhar") + _, err := Accept(w, r, &AcceptOptions{ + CompressionMode: CompressionNoContextTakeover, + }) + assert.ErrorIs(t, errHijack, err) + assert.Equal(t, "extension header", w.Header().Get("Sec-WebSocket-Extensions"), "") + }) + t.Run("withFallback", func(t *testing.T) { + t.Parallel() + + w := newResponseWriter() + r := newRequest("permessage-deflate; harharhar, permessage-deflate") + _, err := Accept(w, r, &AcceptOptions{ + CompressionMode: CompressionNoContextTakeover, + }) + assert.ErrorIs(t, errHijack, err) + assert.Equal(t, "extension header", + w.Header().Get("Sec-WebSocket-Extensions"), + CompressionNoContextTakeover.opts().String(), + ) + }) }) t.Run("requireHttpHijacker", func(t *testing.T) { @@ -67,7 +117,7 @@ func TestAccept(t *testing.T) { r.Header.Set("Connection", "Upgrade") r.Header.Set("Upgrade", "websocket") r.Header.Set("Sec-WebSocket-Version", "13") - r.Header.Set("Sec-WebSocket-Key", "meow123") + r.Header.Set("Sec-WebSocket-Key", xrand.Base64(16)) _, err := Accept(w, r, nil) assert.Contains(t, err, `http.ResponseWriter does not implement http.Hijacker`) @@ -87,7 +137,7 @@ func TestAccept(t *testing.T) { r.Header.Set("Connection", "Upgrade") r.Header.Set("Upgrade", "websocket") r.Header.Set("Sec-WebSocket-Version", "13") - r.Header.Set("Sec-WebSocket-Key", "meow123") + r.Header.Set("Sec-WebSocket-Key", xrand.Base64(16)) _, err := Accept(w, r, nil) assert.Contains(t, err, `failed to hijack connection`) @@ -134,7 +184,15 @@ func Test_verifyClientHandshake(t *testing.T) { }, }, { - name: "badWebSocketKey", + name: "missingWebSocketKey", + h: map[string]string{ + "Connection": "Upgrade", + "Upgrade": "websocket", + "Sec-WebSocket-Version": "13", + }, + }, + { + name: "emptyWebSocketKey", h: map[string]string{ "Connection": "Upgrade", "Upgrade": "websocket", @@ -142,13 +200,43 @@ func Test_verifyClientHandshake(t *testing.T) { "Sec-WebSocket-Key": "", }, }, + { + name: "shortWebSocketKey", + h: map[string]string{ + "Connection": "Upgrade", + "Upgrade": "websocket", + "Sec-WebSocket-Version": "13", + "Sec-WebSocket-Key": xrand.Base64(15), + }, + }, + { + name: "invalidWebSocketKey", + h: map[string]string{ + "Connection": "Upgrade", + "Upgrade": "websocket", + "Sec-WebSocket-Version": "13", + "Sec-WebSocket-Key": "notbase64", + }, + }, + { + name: "extraWebSocketKey", + h: map[string]string{ + "Connection": "Upgrade", + "Upgrade": "websocket", + "Sec-WebSocket-Version": "13", + // Kinda cheeky, but http headers are case-insensitive. + // If 2 sec keys are present, this is a failure condition. + "Sec-WebSocket-Key": xrand.Base64(16), + "sec-webSocket-key": xrand.Base64(16), + }, + }, { name: "badHTTPVersion", h: map[string]string{ "Connection": "Upgrade", "Upgrade": "websocket", "Sec-WebSocket-Version": "13", - "Sec-WebSocket-Key": "meow123", + "Sec-WebSocket-Key": xrand.Base64(16), }, http1: true, }, @@ -158,7 +246,17 @@ func Test_verifyClientHandshake(t *testing.T) { "Connection": "keep-alive, Upgrade", "Upgrade": "websocket", "Sec-WebSocket-Version": "13", - "Sec-WebSocket-Key": "meow123", + "Sec-WebSocket-Key": xrand.Base64(16), + }, + success: true, + }, + { + name: "successSecKeyExtraSpace", + h: map[string]string{ + "Connection": "keep-alive, Upgrade", + "Upgrade": "websocket", + "Sec-WebSocket-Version": "13", + "Sec-WebSocket-Key": " " + xrand.Base64(16) + " ", }, success: true, }, @@ -178,7 +276,7 @@ func Test_verifyClientHandshake(t *testing.T) { } for k, v := range tc.h { - r.Header.Set(k, v) + r.Header.Add(k, v) } _, err := verifyClientRequest(httptest.NewRecorder(), r) @@ -325,59 +423,54 @@ func Test_authenticateOrigin(t *testing.T) { } } -func Test_acceptCompression(t *testing.T) { +func Test_selectDeflate(t *testing.T) { t.Parallel() testCases := []struct { - name string - mode CompressionMode - reqSecWebSocketExtensions string - respSecWebSocketExtensions string - expCopts *compressionOptions - error bool + name string + mode CompressionMode + header string + expCopts *compressionOptions + expOK bool }{ { name: "disabled", mode: CompressionDisabled, expCopts: nil, + expOK: false, }, { name: "noClientSupport", mode: CompressionNoContextTakeover, expCopts: nil, + expOK: false, }, { - name: "permessage-deflate", - mode: CompressionNoContextTakeover, - reqSecWebSocketExtensions: "permessage-deflate; client_max_window_bits", - respSecWebSocketExtensions: "permessage-deflate; client_no_context_takeover; server_no_context_takeover", + name: "permessage-deflate", + mode: CompressionNoContextTakeover, + header: "permessage-deflate; client_max_window_bits", expCopts: &compressionOptions{ clientNoContextTakeover: true, serverNoContextTakeover: true, }, + expOK: true, + }, + { + name: "permessage-deflate/unknown-parameter", + mode: CompressionNoContextTakeover, + header: "permessage-deflate; meow", + expOK: false, }, { - name: "permessage-deflate/error", - mode: CompressionNoContextTakeover, - reqSecWebSocketExtensions: "permessage-deflate; meow", - error: true, + name: "permessage-deflate/unknown-parameter", + mode: CompressionNoContextTakeover, + header: "permessage-deflate; meow, permessage-deflate; client_max_window_bits", + expCopts: &compressionOptions{ + clientNoContextTakeover: true, + serverNoContextTakeover: true, + }, + expOK: true, }, - // { - // name: "x-webkit-deflate-frame", - // mode: CompressionNoContextTakeover, - // reqSecWebSocketExtensions: "x-webkit-deflate-frame; no_context_takeover", - // respSecWebSocketExtensions: "x-webkit-deflate-frame; no_context_takeover", - // expCopts: &compressionOptions{ - // clientNoContextTakeover: true, - // serverNoContextTakeover: true, - // }, - // }, - // { - // name: "x-webkit-deflate/error", - // mode: CompressionNoContextTakeover, - // reqSecWebSocketExtensions: "x-webkit-deflate-frame; max_window_bits", - // error: true, - // }, } for _, tc := range testCases { @@ -385,19 +478,11 @@ func Test_acceptCompression(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - r := httptest.NewRequest(http.MethodGet, "/", nil) - r.Header.Set("Sec-WebSocket-Extensions", tc.reqSecWebSocketExtensions) - - w := httptest.NewRecorder() - copts, err := acceptCompression(r, w, tc.mode) - if tc.error { - assert.Error(t, err) - return - } - - assert.Success(t, err) + h := http.Header{} + h.Set("Sec-WebSocket-Extensions", tc.header) + copts, ok := selectDeflate(websocketExtensions(h), tc.mode) + assert.Equal(t, "selected options", tc.expOK, ok) assert.Equal(t, "compression options", tc.expCopts, copts) - assert.Equal(t, "Sec-WebSocket-Extensions", tc.respSecWebSocketExtensions, w.Header().Get("Sec-WebSocket-Extensions")) }) } } diff --git a/autobahn_test.go b/autobahn_test.go index e56a4912..57ceebd5 100644 --- a/autobahn_test.go +++ b/autobahn_test.go @@ -1,3 +1,4 @@ +//go:build !js // +build !js package websocket_test @@ -5,8 +6,9 @@ package websocket_test import ( "context" "encoding/json" + "errors" "fmt" - "io/ioutil" + "io" "net" "os" "os/exec" @@ -19,6 +21,7 @@ import ( "nhooyr.io/websocket/internal/errd" "nhooyr.io/websocket/internal/test/assert" "nhooyr.io/websocket/internal/test/wstest" + "nhooyr.io/websocket/internal/util" ) var excludedAutobahnCases = []string{ @@ -28,25 +31,43 @@ var excludedAutobahnCases = []string{ // We skip the tests related to requestMaxWindowBits as that is unimplemented due // to limitations in compress/flate. See https://github.com/golang/go/issues/3155 - // Same with klauspost/compress which doesn't allow adjusting the sliding window size. "13.3.*", "13.4.*", "13.5.*", "13.6.*", } var autobahnCases = []string{"*"} +// Used to run individual test cases. autobahnCases runs only those cases matched +// and not excluded by excludedAutobahnCases. Adding cases here means excludedAutobahnCases +// is niled. +var onlyAutobahnCases = []string{} + func TestAutobahn(t *testing.T) { t.Parallel() - if os.Getenv("AUTOBAHN_TEST") == "" { + if os.Getenv("AUTOBAHN") == "" { t.SkipNow() } - ctx, cancel := context.WithTimeout(context.Background(), time.Minute*15) + if os.Getenv("AUTOBAHN") == "fast" { + // These are the slow tests. + excludedAutobahnCases = append(excludedAutobahnCases, + "9.*", "12.*", "13.*", + ) + } + + if len(onlyAutobahnCases) > 0 { + excludedAutobahnCases = []string{} + autobahnCases = onlyAutobahnCases + } + + ctx, cancel := context.WithTimeout(context.Background(), time.Hour) defer cancel() - wstestURL, closeFn, err := wstestClientServer(ctx) + wstestURL, closeFn, err := wstestServer(t, ctx) assert.Success(t, err) - defer closeFn() + defer func() { + assert.Success(t, closeFn()) + }() err = waitWS(ctx, wstestURL) assert.Success(t, err) @@ -61,7 +82,9 @@ func TestAutobahn(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5) defer cancel() - c, _, err := websocket.Dial(ctx, fmt.Sprintf(wstestURL+"/runCase?case=%v&agent=main", i), nil) + c, _, err := websocket.Dial(ctx, fmt.Sprintf(wstestURL+"/runCase?case=%v&agent=main", i), &websocket.DialOptions{ + CompressionMode: websocket.CompressionContextTakeover, + }) assert.Success(t, err) err = wstest.EchoLoop(ctx, c) t.Logf("echoLoop: %v", err) @@ -73,7 +96,7 @@ func TestAutobahn(t *testing.T) { assert.Success(t, err) c.Close(websocket.StatusNormalClosure, "") - checkWSTestIndex(t, "./ci/out/wstestClientReports/index.json") + checkWSTestIndex(t, "./ci/out/autobahn-report/index.json") } func waitWS(ctx context.Context, url string) error { @@ -92,17 +115,24 @@ func waitWS(ctx context.Context, url string) error { return ctx.Err() } -func wstestClientServer(ctx context.Context) (url string, closeFn func(), err error) { +func wstestServer(tb testing.TB, ctx context.Context) (url string, closeFn func() error, err error) { + defer errd.Wrap(&err, "failed to start autobahn wstest server") + serverAddr, err := unusedListenAddr() if err != nil { return "", nil, err } + _, serverPort, err := net.SplitHostPort(serverAddr) + if err != nil { + return "", nil, err + } url = "ws://" + serverAddr + const outDir = "ci/out/autobahn-report" specFile, err := tempJSONFile(map[string]interface{}{ "url": url, - "outdir": "ci/out/wstestClientReports", + "outdir": outDir, "cases": autobahnCases, "exclude-cases": excludedAutobahnCases, }) @@ -110,26 +140,71 @@ func wstestClientServer(ctx context.Context) (url string, closeFn func(), err er return "", nil, fmt.Errorf("failed to write spec: %w", err) } - ctx, cancel := context.WithTimeout(context.Background(), time.Minute*15) + ctx, cancel := context.WithTimeout(ctx, time.Hour) defer func() { if err != nil { cancel() } }() - args := []string{"--mode", "fuzzingserver", "--spec", specFile, + dockerPull := exec.CommandContext(ctx, "docker", "pull", "crossbario/autobahn-testsuite") + dockerPull.Stdout = util.WriterFunc(func(p []byte) (int, error) { + tb.Log(string(p)) + return len(p), nil + }) + dockerPull.Stderr = util.WriterFunc(func(p []byte) (int, error) { + tb.Log(string(p)) + return len(p), nil + }) + tb.Log(dockerPull) + err = dockerPull.Run() + if err != nil { + return "", nil, fmt.Errorf("failed to pull docker image: %w", err) + } + + wd, err := os.Getwd() + if err != nil { + return "", nil, err + } + + var args []string + args = append(args, "run", "-i", "--rm", + "-v", fmt.Sprintf("%s:%[1]s", specFile), + "-v", fmt.Sprintf("%s/ci:/ci", wd), + fmt.Sprintf("-p=%s:%s", serverAddr, serverPort), + "crossbario/autobahn-testsuite", + ) + args = append(args, "wstest", "--mode", "fuzzingserver", "--spec", specFile, // Disables some server that runs as part of fuzzingserver mode. // See https://github.com/crossbario/autobahn-testsuite/blob/058db3a36b7c3a1edf68c282307c6b899ca4857f/autobahntestsuite/autobahntestsuite/wstest.py#L124 "--webport=0", - } - wstest := exec.CommandContext(ctx, "wstest", args...) + ) + wstest := exec.CommandContext(ctx, "docker", args...) + wstest.Stdout = util.WriterFunc(func(p []byte) (int, error) { + tb.Log(string(p)) + return len(p), nil + }) + wstest.Stderr = util.WriterFunc(func(p []byte) (int, error) { + tb.Log(string(p)) + return len(p), nil + }) + tb.Log(wstest) err = wstest.Start() if err != nil { return "", nil, fmt.Errorf("failed to start wstest: %w", err) } - return url, func() { - wstest.Process.Kill() + return url, func() error { + err = wstest.Process.Kill() + if err != nil { + return fmt.Errorf("failed to kill wstest: %w", err) + } + err = wstest.Wait() + var ee *exec.ExitError + if errors.As(err, &ee) && ee.ExitCode() == -1 { + return nil + } + return err }, nil } @@ -146,7 +221,7 @@ func wstestCaseCount(ctx context.Context, url string) (cases int, err error) { if err != nil { return 0, err } - b, err := ioutil.ReadAll(r) + b, err := io.ReadAll(r) if err != nil { return 0, err } @@ -161,7 +236,7 @@ func wstestCaseCount(ctx context.Context, url string) (cases int, err error) { } func checkWSTestIndex(t *testing.T, path string) { - wstestOut, err := ioutil.ReadFile(path) + wstestOut, err := os.ReadFile(path) assert.Success(t, err) var indexJSON map[string]map[string]struct { @@ -206,7 +281,7 @@ func unusedListenAddr() (_ string, err error) { } func tempJSONFile(v interface{}) (string, error) { - f, err := ioutil.TempFile("", "temp.json") + f, err := os.CreateTemp("", "temp.json") if err != nil { return "", fmt.Errorf("temp file: %w", err) } diff --git a/ci/all.sh b/ci/all.sh deleted file mode 100755 index 1ee7640f..00000000 --- a/ci/all.sh +++ /dev/null @@ -1,12 +0,0 @@ -#!/usr/bin/env bash -set -euo pipefail - -main() { - cd "$(dirname "$0")/.." - - ./ci/fmt.sh - ./ci/lint.sh - ./ci/test.sh "$@" -} - -main "$@" diff --git a/ci/bench.sh b/ci/bench.sh new file mode 100755 index 00000000..a553b93a --- /dev/null +++ b/ci/bench.sh @@ -0,0 +1,9 @@ +#!/bin/sh +set -eu +cd -- "$(dirname "$0")/.." + +go test --run=^$ --bench=. --benchmem --memprofile ci/out/prof.mem --cpuprofile ci/out/prof.cpu -o ci/out/websocket.test "$@" . +( + cd ./internal/thirdparty + go test --run=^$ --bench=. --benchmem --memprofile ../../ci/out/prof-thirdparty.mem --cpuprofile ../../ci/out/prof-thirdparty.cpu -o ../../ci/out/thirdparty.test "$@" . +) diff --git a/ci/container/Dockerfile b/ci/container/Dockerfile deleted file mode 100644 index 0c6c2a54..00000000 --- a/ci/container/Dockerfile +++ /dev/null @@ -1,14 +0,0 @@ -FROM golang - -RUN apt-get update -RUN apt-get install -y npm shellcheck chromium - -ENV GO111MODULE=on -RUN go get golang.org/x/tools/cmd/goimports -RUN go get mvdan.cc/sh/v3/cmd/shfmt -RUN go get golang.org/x/tools/cmd/stringer -RUN go get golang.org/x/lint/golint -RUN go get github.com/agnivade/wasmbrowsertest - -RUN npm --unsafe-perm=true install -g prettier -RUN npm --unsafe-perm=true install -g netlify-cli diff --git a/ci/fmt.sh b/ci/fmt.sh index e6a2d689..6e5a68e4 100755 --- a/ci/fmt.sh +++ b/ci/fmt.sh @@ -1,38 +1,20 @@ -#!/usr/bin/env bash -set -euo pipefail +#!/bin/sh +set -eu +cd -- "$(dirname "$0")/.." -main() { - cd "$(dirname "$0")/.." +go mod tidy +(cd ./internal/thirdparty && go mod tidy) +(cd ./internal/examples && go mod tidy) +gofmt -w -s . +go run golang.org/x/tools/cmd/goimports@latest -w "-local=$(go list -m)" . - go mod tidy - gofmt -w -s . - goimports -w "-local=$(go list -m)" . +npx prettier@3.0.3 \ + --write \ + --log-level=warn \ + --print-width=90 \ + --no-semi \ + --single-quote \ + --arrow-parens=avoid \ + $(git ls-files "*.yml" "*.md" "*.js" "*.css" "*.html") - prettier \ - --write \ - --print-width=120 \ - --no-semi \ - --trailing-comma=all \ - --loglevel=warn \ - --arrow-parens=avoid \ - $(git ls-files "*.yml" "*.md" "*.js" "*.css" "*.html") - shfmt -i 2 -w -s -sr $(git ls-files "*.sh") - - stringer -type=opcode,MessageType,StatusCode -output=stringer.go - - if [[ ${CI-} ]]; then - ensure_fmt - fi -} - -ensure_fmt() { - if [[ $(git ls-files --other --modified --exclude-standard) ]]; then - git -c color.ui=always --no-pager diff - echo - echo "Please run the following locally:" - echo " ./ci/fmt.sh" - exit 1 - fi -} - -main "$@" +go run golang.org/x/tools/cmd/stringer@latest -type=opcode,MessageType,StatusCode -output=stringer.go diff --git a/ci/lint.sh b/ci/lint.sh index e1053d13..3cf8eee4 100755 --- a/ci/lint.sh +++ b/ci/lint.sh @@ -1,16 +1,33 @@ -#!/usr/bin/env bash -set -euo pipefail +#!/bin/sh +set -eu +cd -- "$(dirname "$0")/.." -main() { - cd "$(dirname "$0")/.." +go vet ./... +GOOS=js GOARCH=wasm go vet ./... - go vet ./... - GOOS=js GOARCH=wasm go vet ./... - - golint -set_exit_status ./... - GOOS=js GOARCH=wasm golint -set_exit_status ./... +go install honnef.co/go/tools/cmd/staticcheck@latest +staticcheck ./... +GOOS=js GOARCH=wasm staticcheck ./... - shellcheck --exclude=SC2046 $(git ls-files "*.sh") +govulncheck() { + tmpf=$(mktemp) + if ! command govulncheck "$@" >"$tmpf" 2>&1; then + cat "$tmpf" + fi } +go install golang.org/x/vuln/cmd/govulncheck@latest +govulncheck ./... +GOOS=js GOARCH=wasm govulncheck ./... -main "$@" +( + cd ./internal/examples + go vet ./... + staticcheck ./... + govulncheck ./... +) +( + cd ./internal/thirdparty + go vet ./... + staticcheck ./... + govulncheck ./... +) diff --git a/ci/test.sh b/ci/test.sh index 95ef7101..83bb9832 100755 --- a/ci/test.sh +++ b/ci/test.sh @@ -1,25 +1,23 @@ -#!/usr/bin/env bash -set -euo pipefail +#!/bin/sh +set -eu +cd -- "$(dirname "$0")/.." -main() { - cd "$(dirname "$0")/.." +( + cd ./internal/examples + go test "$@" ./... +) +( + cd ./internal/thirdparty + go test "$@" ./... +) - go test -timeout=30m -covermode=atomic -coverprofile=ci/out/coverage.prof -coverpkg=./... "$@" ./... - sed -i '/stringer\.go/d' ci/out/coverage.prof - sed -i '/nhooyr.io\/websocket\/internal\/test/d' ci/out/coverage.prof - sed -i '/examples/d' ci/out/coverage.prof +go install github.com/agnivade/wasmbrowsertest@latest +go test --race --bench=. --timeout=1h --covermode=atomic --coverprofile=ci/out/coverage.prof --coverpkg=./... "$@" ./... +sed -i.bak '/stringer\.go/d' ci/out/coverage.prof +sed -i.bak '/nhooyr.io\/websocket\/internal\/test/d' ci/out/coverage.prof +sed -i.bak '/examples/d' ci/out/coverage.prof - # Last line is the total coverage. - go tool cover -func ci/out/coverage.prof | tail -n1 +# Last line is the total coverage. +go tool cover -func ci/out/coverage.prof | tail -n1 - go tool cover -html=ci/out/coverage.prof -o=ci/out/coverage.html - - if [[ ${CI-} && ${GITHUB_REF-} == *master ]]; then - local deployDir - deployDir="$(mktemp -d)" - cp ci/out/coverage.html "$deployDir/index.html" - netlify deploy --prod "--dir=$deployDir" - fi -} - -main "$@" +go tool cover -html=ci/out/coverage.prof -o=ci/out/coverage.html diff --git a/close.go b/close.go index 7cbc19e9..c3dee7e0 100644 --- a/close.go +++ b/close.go @@ -1,8 +1,17 @@ +//go:build !js +// +build !js + package websocket import ( + "context" + "encoding/binary" "errors" "fmt" + "net" + "time" + + "nhooyr.io/websocket/internal/errd" ) // StatusCode represents a WebSocket status code. @@ -74,3 +83,217 @@ func CloseStatus(err error) StatusCode { } return -1 } + +// Close performs the WebSocket close handshake with the given status code and reason. +// +// It will write a WebSocket close frame with a timeout of 5s and then wait 5s for +// the peer to send a close frame. +// All data messages received from the peer during the close handshake will be discarded. +// +// The connection can only be closed once. Additional calls to Close +// are no-ops. +// +// The maximum length of reason must be 125 bytes. Avoid +// sending a dynamic reason. +// +// Close will unblock all goroutines interacting with the connection once +// complete. +func (c *Conn) Close(code StatusCode, reason string) error { + defer c.wg.Wait() + return c.closeHandshake(code, reason) +} + +// CloseNow closes the WebSocket connection without attempting a close handshake. +// Use when you do not want the overhead of the close handshake. +func (c *Conn) CloseNow() (err error) { + defer c.wg.Wait() + defer errd.Wrap(&err, "failed to close WebSocket") + + if c.isClosed() { + return net.ErrClosed + } + + c.close(nil) + return c.closeErr +} + +func (c *Conn) closeHandshake(code StatusCode, reason string) (err error) { + defer errd.Wrap(&err, "failed to close WebSocket") + + writeErr := c.writeClose(code, reason) + closeHandshakeErr := c.waitCloseHandshake() + + if writeErr != nil { + return writeErr + } + + if CloseStatus(closeHandshakeErr) == -1 && !errors.Is(net.ErrClosed, closeHandshakeErr) { + return closeHandshakeErr + } + + return nil +} + +func (c *Conn) writeClose(code StatusCode, reason string) error { + c.closeMu.Lock() + wroteClose := c.wroteClose + c.wroteClose = true + c.closeMu.Unlock() + if wroteClose { + return net.ErrClosed + } + + ce := CloseError{ + Code: code, + Reason: reason, + } + + var p []byte + var marshalErr error + if ce.Code != StatusNoStatusRcvd { + p, marshalErr = ce.bytes() + } + + writeErr := c.writeControl(context.Background(), opClose, p) + if CloseStatus(writeErr) != -1 { + // Not a real error if it's due to a close frame being received. + writeErr = nil + } + + // We do this after in case there was an error writing the close frame. + c.setCloseErr(fmt.Errorf("sent close frame: %w", ce)) + + if marshalErr != nil { + return marshalErr + } + return writeErr +} + +func (c *Conn) waitCloseHandshake() error { + defer c.close(nil) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + err := c.readMu.lock(ctx) + if err != nil { + return err + } + defer c.readMu.unlock() + + if c.readCloseFrameErr != nil { + return c.readCloseFrameErr + } + + for i := int64(0); i < c.msgReader.payloadLength; i++ { + _, err := c.br.ReadByte() + if err != nil { + return err + } + } + + for { + h, err := c.readLoop(ctx) + if err != nil { + return err + } + + for i := int64(0); i < h.payloadLength; i++ { + _, err := c.br.ReadByte() + if err != nil { + return err + } + } + } +} + +func parseClosePayload(p []byte) (CloseError, error) { + if len(p) == 0 { + return CloseError{ + Code: StatusNoStatusRcvd, + }, nil + } + + if len(p) < 2 { + return CloseError{}, fmt.Errorf("close payload %q too small, cannot even contain the 2 byte status code", p) + } + + ce := CloseError{ + Code: StatusCode(binary.BigEndian.Uint16(p)), + Reason: string(p[2:]), + } + + if !validWireCloseCode(ce.Code) { + return CloseError{}, fmt.Errorf("invalid status code %v", ce.Code) + } + + return ce, nil +} + +// See http://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number +// and https://tools.ietf.org/html/rfc6455#section-7.4.1 +func validWireCloseCode(code StatusCode) bool { + switch code { + case statusReserved, StatusNoStatusRcvd, StatusAbnormalClosure, StatusTLSHandshake: + return false + } + + if code >= StatusNormalClosure && code <= StatusBadGateway { + return true + } + if code >= 3000 && code <= 4999 { + return true + } + + return false +} + +func (ce CloseError) bytes() ([]byte, error) { + p, err := ce.bytesErr() + if err != nil { + err = fmt.Errorf("failed to marshal close frame: %w", err) + ce = CloseError{ + Code: StatusInternalError, + } + p, _ = ce.bytesErr() + } + return p, err +} + +const maxCloseReason = maxControlPayload - 2 + +func (ce CloseError) bytesErr() ([]byte, error) { + if len(ce.Reason) > maxCloseReason { + return nil, fmt.Errorf("reason string max is %v but got %q with length %v", maxCloseReason, ce.Reason, len(ce.Reason)) + } + + if !validWireCloseCode(ce.Code) { + return nil, fmt.Errorf("status code %v cannot be set", ce.Code) + } + + buf := make([]byte, 2+len(ce.Reason)) + binary.BigEndian.PutUint16(buf, uint16(ce.Code)) + copy(buf[2:], ce.Reason) + return buf, nil +} + +func (c *Conn) setCloseErr(err error) { + c.closeMu.Lock() + c.setCloseErrLocked(err) + c.closeMu.Unlock() +} + +func (c *Conn) setCloseErrLocked(err error) { + if c.closeErr == nil && err != nil { + c.closeErr = fmt.Errorf("WebSocket closed: %w", err) + } +} + +func (c *Conn) isClosed() bool { + select { + case <-c.closed: + return true + default: + return false + } +} diff --git a/close_notjs.go b/close_notjs.go deleted file mode 100644 index 4251311d..00000000 --- a/close_notjs.go +++ /dev/null @@ -1,211 +0,0 @@ -// +build !js - -package websocket - -import ( - "context" - "encoding/binary" - "errors" - "fmt" - "log" - "time" - - "nhooyr.io/websocket/internal/errd" -) - -// Close performs the WebSocket close handshake with the given status code and reason. -// -// It will write a WebSocket close frame with a timeout of 5s and then wait 5s for -// the peer to send a close frame. -// All data messages received from the peer during the close handshake will be discarded. -// -// The connection can only be closed once. Additional calls to Close -// are no-ops. -// -// The maximum length of reason must be 125 bytes. Avoid -// sending a dynamic reason. -// -// Close will unblock all goroutines interacting with the connection once -// complete. -func (c *Conn) Close(code StatusCode, reason string) error { - return c.closeHandshake(code, reason) -} - -func (c *Conn) closeHandshake(code StatusCode, reason string) (err error) { - defer errd.Wrap(&err, "failed to close WebSocket") - - writeErr := c.writeClose(code, reason) - closeHandshakeErr := c.waitCloseHandshake() - - if writeErr != nil { - return writeErr - } - - if CloseStatus(closeHandshakeErr) == -1 { - return closeHandshakeErr - } - - return nil -} - -var errAlreadyWroteClose = errors.New("already wrote close") - -func (c *Conn) writeClose(code StatusCode, reason string) error { - c.closeMu.Lock() - wroteClose := c.wroteClose - c.wroteClose = true - c.closeMu.Unlock() - if wroteClose { - return errAlreadyWroteClose - } - - ce := CloseError{ - Code: code, - Reason: reason, - } - - var p []byte - var marshalErr error - if ce.Code != StatusNoStatusRcvd { - p, marshalErr = ce.bytes() - if marshalErr != nil { - log.Printf("websocket: %v", marshalErr) - } - } - - writeErr := c.writeControl(context.Background(), opClose, p) - if CloseStatus(writeErr) != -1 { - // Not a real error if it's due to a close frame being received. - writeErr = nil - } - - // We do this after in case there was an error writing the close frame. - c.setCloseErr(fmt.Errorf("sent close frame: %w", ce)) - - if marshalErr != nil { - return marshalErr - } - return writeErr -} - -func (c *Conn) waitCloseHandshake() error { - defer c.close(nil) - - ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) - defer cancel() - - err := c.readMu.lock(ctx) - if err != nil { - return err - } - defer c.readMu.unlock() - - if c.readCloseFrameErr != nil { - return c.readCloseFrameErr - } - - for { - h, err := c.readLoop(ctx) - if err != nil { - return err - } - - for i := int64(0); i < h.payloadLength; i++ { - _, err := c.br.ReadByte() - if err != nil { - return err - } - } - } -} - -func parseClosePayload(p []byte) (CloseError, error) { - if len(p) == 0 { - return CloseError{ - Code: StatusNoStatusRcvd, - }, nil - } - - if len(p) < 2 { - return CloseError{}, fmt.Errorf("close payload %q too small, cannot even contain the 2 byte status code", p) - } - - ce := CloseError{ - Code: StatusCode(binary.BigEndian.Uint16(p)), - Reason: string(p[2:]), - } - - if !validWireCloseCode(ce.Code) { - return CloseError{}, fmt.Errorf("invalid status code %v", ce.Code) - } - - return ce, nil -} - -// See http://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number -// and https://tools.ietf.org/html/rfc6455#section-7.4.1 -func validWireCloseCode(code StatusCode) bool { - switch code { - case statusReserved, StatusNoStatusRcvd, StatusAbnormalClosure, StatusTLSHandshake: - return false - } - - if code >= StatusNormalClosure && code <= StatusBadGateway { - return true - } - if code >= 3000 && code <= 4999 { - return true - } - - return false -} - -func (ce CloseError) bytes() ([]byte, error) { - p, err := ce.bytesErr() - if err != nil { - err = fmt.Errorf("failed to marshal close frame: %w", err) - ce = CloseError{ - Code: StatusInternalError, - } - p, _ = ce.bytesErr() - } - return p, err -} - -const maxCloseReason = maxControlPayload - 2 - -func (ce CloseError) bytesErr() ([]byte, error) { - if len(ce.Reason) > maxCloseReason { - return nil, fmt.Errorf("reason string max is %v but got %q with length %v", maxCloseReason, ce.Reason, len(ce.Reason)) - } - - if !validWireCloseCode(ce.Code) { - return nil, fmt.Errorf("status code %v cannot be set", ce.Code) - } - - buf := make([]byte, 2+len(ce.Reason)) - binary.BigEndian.PutUint16(buf, uint16(ce.Code)) - copy(buf[2:], ce.Reason) - return buf, nil -} - -func (c *Conn) setCloseErr(err error) { - c.closeMu.Lock() - c.setCloseErrLocked(err) - c.closeMu.Unlock() -} - -func (c *Conn) setCloseErrLocked(err error) { - if c.closeErr == nil { - c.closeErr = fmt.Errorf("WebSocket closed: %w", err) - } -} - -func (c *Conn) isClosed() bool { - select { - case <-c.closed: - return true - default: - return false - } -} diff --git a/close_test.go b/close_test.go index 00a48d9e..6bf3c256 100644 --- a/close_test.go +++ b/close_test.go @@ -1,3 +1,4 @@ +//go:build !js // +build !js package websocket diff --git a/compress.go b/compress.go index 80b46d1c..1f3adcfb 100644 --- a/compress.go +++ b/compress.go @@ -1,39 +1,233 @@ +//go:build !js +// +build !js + package websocket -// CompressionMode represents the modes available to the deflate extension. +import ( + "compress/flate" + "io" + "sync" +) + +// CompressionMode represents the modes available to the permessage-deflate extension. // See https://tools.ietf.org/html/rfc7692 // -// A compatibility layer is implemented for the older deflate-frame extension used -// by safari. See https://tools.ietf.org/html/draft-tyoshino-hybi-websocket-perframe-deflate-06 -// It will work the same in every way except that we cannot signal to the peer we -// want to use no context takeover on our side, we can only signal that they should. -// It is however currently disabled due to Safari bugs. See https://github.com/nhooyr/websocket/issues/218 +// Works in all modern browsers except Safari which does not implement the permessage-deflate extension. +// +// Compression is only used if the peer supports the mode selected. type CompressionMode int const ( - // CompressionNoContextTakeover grabs a new flate.Reader and flate.Writer as needed - // for every message. This applies to both server and client side. - // - // This means less efficient compression as the sliding window from previous messages - // will not be used but the memory overhead will be lower if the connections - // are long lived and seldom used. + // CompressionDisabled disables the negotiation of the permessage-deflate extension. // - // The message will only be compressed if greater than 512 bytes. - CompressionNoContextTakeover CompressionMode = iota + // This is the default. Do not enable compression without benchmarking for your particular use case first. + CompressionDisabled CompressionMode = iota - // CompressionContextTakeover uses a flate.Reader and flate.Writer per connection. - // This enables reusing the sliding window from previous messages. - // As most WebSocket protocols are repetitive, this can be very efficient. - // It carries an overhead of 8 kB for every connection compared to CompressionNoContextTakeover. + // CompressionContextTakeover compresses each message greater than 128 bytes reusing the 32 KB sliding window from + // previous messages. i.e compression context across messages is preserved. + // + // As most WebSocket protocols are text based and repetitive, this compression mode can be very efficient. // - // If the peer negotiates NoContextTakeover on the client or server side, it will be - // used instead as this is required by the RFC. + // The memory overhead is a fixed 32 KB sliding window, a fixed 1.2 MB flate.Writer and a sync.Pool of 40 KB flate.Reader's + // that are used when reading and then returned. + // + // Thus, it uses more memory than CompressionNoContextTakeover but compresses more efficiently. + // + // If the peer does not support CompressionContextTakeover then we will fall back to CompressionNoContextTakeover. CompressionContextTakeover - // CompressionDisabled disables the deflate extension. + // CompressionNoContextTakeover compresses each message greater than 512 bytes. Each message is compressed with + // a new 1.2 MB flate.Writer pulled from a sync.Pool. Each message is read with a 40 KB flate.Reader pulled from + // a sync.Pool. + // + // This means less efficient compression as the sliding window from previous messages will not be used but the + // memory overhead will be lower as there will be no fixed cost for the flate.Writer nor the 32 KB sliding window. + // Especially if the connections are long lived and seldom written to. // - // Use this if you are using a predominantly binary protocol with very - // little duplication in between messages or CPU and memory are more - // important than bandwidth. - CompressionDisabled + // Thus, it uses less memory than CompressionContextTakeover but compresses less efficiently. + // + // If the peer does not support CompressionNoContextTakeover then we will fall back to CompressionDisabled. + CompressionNoContextTakeover ) + +func (m CompressionMode) opts() *compressionOptions { + return &compressionOptions{ + clientNoContextTakeover: m == CompressionNoContextTakeover, + serverNoContextTakeover: m == CompressionNoContextTakeover, + } +} + +type compressionOptions struct { + clientNoContextTakeover bool + serverNoContextTakeover bool +} + +func (copts *compressionOptions) String() string { + s := "permessage-deflate" + if copts.clientNoContextTakeover { + s += "; client_no_context_takeover" + } + if copts.serverNoContextTakeover { + s += "; server_no_context_takeover" + } + return s +} + +// These bytes are required to get flate.Reader to return. +// They are removed when sending to avoid the overhead as +// WebSocket framing tell's when the message has ended but then +// we need to add them back otherwise flate.Reader keeps +// trying to read more bytes. +const deflateMessageTail = "\x00\x00\xff\xff" + +type trimLastFourBytesWriter struct { + w io.Writer + tail []byte +} + +func (tw *trimLastFourBytesWriter) reset() { + if tw != nil && tw.tail != nil { + tw.tail = tw.tail[:0] + } +} + +func (tw *trimLastFourBytesWriter) Write(p []byte) (int, error) { + if tw.tail == nil { + tw.tail = make([]byte, 0, 4) + } + + extra := len(tw.tail) + len(p) - 4 + + if extra <= 0 { + tw.tail = append(tw.tail, p...) + return len(p), nil + } + + // Now we need to write as many extra bytes as we can from the previous tail. + if extra > len(tw.tail) { + extra = len(tw.tail) + } + if extra > 0 { + _, err := tw.w.Write(tw.tail[:extra]) + if err != nil { + return 0, err + } + + // Shift remaining bytes in tail over. + n := copy(tw.tail, tw.tail[extra:]) + tw.tail = tw.tail[:n] + } + + // If p is less than or equal to 4 bytes, + // all of it is is part of the tail. + if len(p) <= 4 { + tw.tail = append(tw.tail, p...) + return len(p), nil + } + + // Otherwise, only the last 4 bytes are. + tw.tail = append(tw.tail, p[len(p)-4:]...) + + p = p[:len(p)-4] + n, err := tw.w.Write(p) + return n + 4, err +} + +var flateReaderPool sync.Pool + +func getFlateReader(r io.Reader, dict []byte) io.Reader { + fr, ok := flateReaderPool.Get().(io.Reader) + if !ok { + return flate.NewReaderDict(r, dict) + } + fr.(flate.Resetter).Reset(r, dict) + return fr +} + +func putFlateReader(fr io.Reader) { + flateReaderPool.Put(fr) +} + +var flateWriterPool sync.Pool + +func getFlateWriter(w io.Writer) *flate.Writer { + fw, ok := flateWriterPool.Get().(*flate.Writer) + if !ok { + fw, _ = flate.NewWriter(w, flate.BestSpeed) + return fw + } + fw.Reset(w) + return fw +} + +func putFlateWriter(w *flate.Writer) { + flateWriterPool.Put(w) +} + +type slidingWindow struct { + buf []byte +} + +var swPoolMu sync.RWMutex +var swPool = map[int]*sync.Pool{} + +func slidingWindowPool(n int) *sync.Pool { + swPoolMu.RLock() + p, ok := swPool[n] + swPoolMu.RUnlock() + if ok { + return p + } + + p = &sync.Pool{} + + swPoolMu.Lock() + swPool[n] = p + swPoolMu.Unlock() + + return p +} + +func (sw *slidingWindow) init(n int) { + if sw.buf != nil { + return + } + + if n == 0 { + n = 32768 + } + + p := slidingWindowPool(n) + sw2, ok := p.Get().(*slidingWindow) + if ok { + *sw = *sw2 + } else { + sw.buf = make([]byte, 0, n) + } +} + +func (sw *slidingWindow) close() { + sw.buf = sw.buf[:0] + swPoolMu.Lock() + swPool[cap(sw.buf)].Put(sw) + swPoolMu.Unlock() +} + +func (sw *slidingWindow) write(p []byte) { + if len(p) >= cap(sw.buf) { + sw.buf = sw.buf[:cap(sw.buf)] + p = p[len(p)-cap(sw.buf):] + copy(sw.buf, p) + return + } + + left := cap(sw.buf) - len(sw.buf) + if left < len(p) { + // We need to shift spaceNeeded bytes from the end to make room for p at the end. + spaceNeeded := len(p) - left + copy(sw.buf, sw.buf[spaceNeeded:]) + sw.buf = sw.buf[:len(sw.buf)-spaceNeeded] + } + + sw.buf = append(sw.buf, p...) +} diff --git a/compress_notjs.go b/compress_notjs.go deleted file mode 100644 index 809a272c..00000000 --- a/compress_notjs.go +++ /dev/null @@ -1,181 +0,0 @@ -// +build !js - -package websocket - -import ( - "io" - "net/http" - "sync" - - "github.com/klauspost/compress/flate" -) - -func (m CompressionMode) opts() *compressionOptions { - return &compressionOptions{ - clientNoContextTakeover: m == CompressionNoContextTakeover, - serverNoContextTakeover: m == CompressionNoContextTakeover, - } -} - -type compressionOptions struct { - clientNoContextTakeover bool - serverNoContextTakeover bool -} - -func (copts *compressionOptions) setHeader(h http.Header) { - s := "permessage-deflate" - if copts.clientNoContextTakeover { - s += "; client_no_context_takeover" - } - if copts.serverNoContextTakeover { - s += "; server_no_context_takeover" - } - h.Set("Sec-WebSocket-Extensions", s) -} - -// These bytes are required to get flate.Reader to return. -// They are removed when sending to avoid the overhead as -// WebSocket framing tell's when the message has ended but then -// we need to add them back otherwise flate.Reader keeps -// trying to return more bytes. -const deflateMessageTail = "\x00\x00\xff\xff" - -type trimLastFourBytesWriter struct { - w io.Writer - tail []byte -} - -func (tw *trimLastFourBytesWriter) reset() { - if tw != nil && tw.tail != nil { - tw.tail = tw.tail[:0] - } -} - -func (tw *trimLastFourBytesWriter) Write(p []byte) (int, error) { - if tw.tail == nil { - tw.tail = make([]byte, 0, 4) - } - - extra := len(tw.tail) + len(p) - 4 - - if extra <= 0 { - tw.tail = append(tw.tail, p...) - return len(p), nil - } - - // Now we need to write as many extra bytes as we can from the previous tail. - if extra > len(tw.tail) { - extra = len(tw.tail) - } - if extra > 0 { - _, err := tw.w.Write(tw.tail[:extra]) - if err != nil { - return 0, err - } - - // Shift remaining bytes in tail over. - n := copy(tw.tail, tw.tail[extra:]) - tw.tail = tw.tail[:n] - } - - // If p is less than or equal to 4 bytes, - // all of it is is part of the tail. - if len(p) <= 4 { - tw.tail = append(tw.tail, p...) - return len(p), nil - } - - // Otherwise, only the last 4 bytes are. - tw.tail = append(tw.tail, p[len(p)-4:]...) - - p = p[:len(p)-4] - n, err := tw.w.Write(p) - return n + 4, err -} - -var flateReaderPool sync.Pool - -func getFlateReader(r io.Reader, dict []byte) io.Reader { - fr, ok := flateReaderPool.Get().(io.Reader) - if !ok { - return flate.NewReaderDict(r, dict) - } - fr.(flate.Resetter).Reset(r, dict) - return fr -} - -func putFlateReader(fr io.Reader) { - flateReaderPool.Put(fr) -} - -type slidingWindow struct { - buf []byte -} - -var swPoolMu sync.RWMutex -var swPool = map[int]*sync.Pool{} - -func slidingWindowPool(n int) *sync.Pool { - swPoolMu.RLock() - p, ok := swPool[n] - swPoolMu.RUnlock() - if ok { - return p - } - - p = &sync.Pool{} - - swPoolMu.Lock() - swPool[n] = p - swPoolMu.Unlock() - - return p -} - -func (sw *slidingWindow) init(n int) { - if sw.buf != nil { - return - } - - if n == 0 { - n = 32768 - } - - p := slidingWindowPool(n) - buf, ok := p.Get().([]byte) - if ok { - sw.buf = buf[:0] - } else { - sw.buf = make([]byte, 0, n) - } -} - -func (sw *slidingWindow) close() { - if sw.buf == nil { - return - } - - swPoolMu.Lock() - swPool[cap(sw.buf)].Put(sw.buf) - swPoolMu.Unlock() - sw.buf = nil -} - -func (sw *slidingWindow) write(p []byte) { - if len(p) >= cap(sw.buf) { - sw.buf = sw.buf[:cap(sw.buf)] - p = p[len(p)-cap(sw.buf):] - copy(sw.buf, p) - return - } - - left := cap(sw.buf) - len(sw.buf) - if left < len(p) { - // We need to shift spaceNeeded bytes from the end to make room for p at the end. - spaceNeeded := len(p) - left - copy(sw.buf, sw.buf[spaceNeeded:]) - sw.buf = sw.buf[:len(sw.buf)-spaceNeeded] - } - - sw.buf = append(sw.buf, p...) -} diff --git a/compress_test.go b/compress_test.go index 2c4c896c..667e1408 100644 --- a/compress_test.go +++ b/compress_test.go @@ -1,8 +1,12 @@ +//go:build !js // +build !js package websocket import ( + "bytes" + "compress/flate" + "io" "strings" "testing" @@ -32,3 +36,27 @@ func Test_slidingWindow(t *testing.T) { }) } } + +func BenchmarkFlateWriter(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + w, _ := flate.NewWriter(io.Discard, flate.BestSpeed) + // We have to write a byte to get the writer to allocate to its full extent. + w.Write([]byte{'a'}) + w.Flush() + } +} + +func BenchmarkFlateReader(b *testing.B) { + b.ReportAllocs() + + var buf bytes.Buffer + w, _ := flate.NewWriter(&buf, flate.BestSpeed) + w.Write([]byte{'a'}) + w.Flush() + + for i := 0; i < b.N; i++ { + r := flate.NewReader(bytes.NewReader(buf.Bytes())) + io.ReadAll(r) + } +} diff --git a/conn.go b/conn.go index a41808be..e133cd67 100644 --- a/conn.go +++ b/conn.go @@ -1,5 +1,21 @@ +//go:build !js +// +build !js + package websocket +import ( + "bufio" + "context" + "errors" + "fmt" + "io" + "net" + "runtime" + "strconv" + "sync" + "sync/atomic" +) + // MessageType represents the type of a WebSocket message. // See https://tools.ietf.org/html/rfc6455#section-5.6 type MessageType int @@ -11,3 +27,285 @@ const ( // MessageBinary is for binary messages like protobufs. MessageBinary ) + +// Conn represents a WebSocket connection. +// All methods may be called concurrently except for Reader and Read. +// +// You must always read from the connection. Otherwise control +// frames will not be handled. See Reader and CloseRead. +// +// Be sure to call Close on the connection when you +// are finished with it to release associated resources. +// +// On any error from any method, the connection is closed +// with an appropriate reason. +// +// This applies to context expirations as well unfortunately. +// See https://github.com/nhooyr/websocket/issues/242#issuecomment-633182220 +type Conn struct { + noCopy + + subprotocol string + rwc io.ReadWriteCloser + client bool + copts *compressionOptions + flateThreshold int + br *bufio.Reader + bw *bufio.Writer + + readTimeout chan context.Context + writeTimeout chan context.Context + + // Read state. + readMu *mu + readHeaderBuf [8]byte + readControlBuf [maxControlPayload]byte + msgReader *msgReader + readCloseFrameErr error + + // Write state. + msgWriter *msgWriter + writeFrameMu *mu + writeBuf []byte + writeHeaderBuf [8]byte + writeHeader header + + wg sync.WaitGroup + closed chan struct{} + closeMu sync.Mutex + closeErr error + wroteClose bool + + pingCounter int32 + activePingsMu sync.Mutex + activePings map[string]chan<- struct{} +} + +type connConfig struct { + subprotocol string + rwc io.ReadWriteCloser + client bool + copts *compressionOptions + flateThreshold int + + br *bufio.Reader + bw *bufio.Writer +} + +func newConn(cfg connConfig) *Conn { + c := &Conn{ + subprotocol: cfg.subprotocol, + rwc: cfg.rwc, + client: cfg.client, + copts: cfg.copts, + flateThreshold: cfg.flateThreshold, + + br: cfg.br, + bw: cfg.bw, + + readTimeout: make(chan context.Context), + writeTimeout: make(chan context.Context), + + closed: make(chan struct{}), + activePings: make(map[string]chan<- struct{}), + } + + c.readMu = newMu(c) + c.writeFrameMu = newMu(c) + + c.msgReader = newMsgReader(c) + + c.msgWriter = newMsgWriter(c) + if c.client { + c.writeBuf = extractBufioWriterBuf(c.bw, c.rwc) + } + + if c.flate() && c.flateThreshold == 0 { + c.flateThreshold = 128 + if !c.msgWriter.flateContextTakeover() { + c.flateThreshold = 512 + } + } + + runtime.SetFinalizer(c, func(c *Conn) { + c.close(errors.New("connection garbage collected")) + }) + + c.wg.Add(1) + go func() { + defer c.wg.Done() + c.timeoutLoop() + }() + + return c +} + +// Subprotocol returns the negotiated subprotocol. +// An empty string means the default protocol. +func (c *Conn) Subprotocol() string { + return c.subprotocol +} + +func (c *Conn) close(err error) { + c.closeMu.Lock() + defer c.closeMu.Unlock() + + if c.isClosed() { + return + } + if err == nil { + err = c.rwc.Close() + } + c.setCloseErrLocked(err) + + close(c.closed) + runtime.SetFinalizer(c, nil) + + // Have to close after c.closed is closed to ensure any goroutine that wakes up + // from the connection being closed also sees that c.closed is closed and returns + // closeErr. + c.rwc.Close() + + c.wg.Add(1) + go func() { + defer c.wg.Done() + c.msgWriter.close() + c.msgReader.close() + }() +} + +func (c *Conn) timeoutLoop() { + readCtx := context.Background() + writeCtx := context.Background() + + for { + select { + case <-c.closed: + return + + case writeCtx = <-c.writeTimeout: + case readCtx = <-c.readTimeout: + + case <-readCtx.Done(): + c.setCloseErr(fmt.Errorf("read timed out: %w", readCtx.Err())) + c.wg.Add(1) + go func() { + defer c.wg.Done() + c.writeError(StatusPolicyViolation, errors.New("read timed out")) + }() + case <-writeCtx.Done(): + c.close(fmt.Errorf("write timed out: %w", writeCtx.Err())) + return + } + } +} + +func (c *Conn) flate() bool { + return c.copts != nil +} + +// Ping sends a ping to the peer and waits for a pong. +// Use this to measure latency or ensure the peer is responsive. +// Ping must be called concurrently with Reader as it does +// not read from the connection but instead waits for a Reader call +// to read the pong. +// +// TCP Keepalives should suffice for most use cases. +func (c *Conn) Ping(ctx context.Context) error { + p := atomic.AddInt32(&c.pingCounter, 1) + + err := c.ping(ctx, strconv.Itoa(int(p))) + if err != nil { + return fmt.Errorf("failed to ping: %w", err) + } + return nil +} + +func (c *Conn) ping(ctx context.Context, p string) error { + pong := make(chan struct{}, 1) + + c.activePingsMu.Lock() + c.activePings[p] = pong + c.activePingsMu.Unlock() + + defer func() { + c.activePingsMu.Lock() + delete(c.activePings, p) + c.activePingsMu.Unlock() + }() + + err := c.writeControl(ctx, opPing, []byte(p)) + if err != nil { + return err + } + + select { + case <-c.closed: + return net.ErrClosed + case <-ctx.Done(): + err := fmt.Errorf("failed to wait for pong: %w", ctx.Err()) + c.close(err) + return err + case <-pong: + return nil + } +} + +type mu struct { + c *Conn + ch chan struct{} +} + +func newMu(c *Conn) *mu { + return &mu{ + c: c, + ch: make(chan struct{}, 1), + } +} + +func (m *mu) forceLock() { + m.ch <- struct{}{} +} + +func (m *mu) tryLock() bool { + select { + case m.ch <- struct{}{}: + return true + default: + return false + } +} + +func (m *mu) lock(ctx context.Context) error { + select { + case <-m.c.closed: + return net.ErrClosed + case <-ctx.Done(): + err := fmt.Errorf("failed to acquire lock: %w", ctx.Err()) + m.c.close(err) + return err + case m.ch <- struct{}{}: + // To make sure the connection is certainly alive. + // As it's possible the send on m.ch was selected + // over the receive on closed. + select { + case <-m.c.closed: + // Make sure to release. + m.unlock() + return net.ErrClosed + default: + } + return nil + } +} + +func (m *mu) unlock() { + select { + case <-m.ch: + default: + } +} + +type noCopy struct{} + +func (*noCopy) Lock() {} diff --git a/conn_notjs.go b/conn_notjs.go deleted file mode 100644 index 0c85ab77..00000000 --- a/conn_notjs.go +++ /dev/null @@ -1,265 +0,0 @@ -// +build !js - -package websocket - -import ( - "bufio" - "context" - "errors" - "fmt" - "io" - "runtime" - "strconv" - "sync" - "sync/atomic" -) - -// Conn represents a WebSocket connection. -// All methods may be called concurrently except for Reader and Read. -// -// You must always read from the connection. Otherwise control -// frames will not be handled. See Reader and CloseRead. -// -// Be sure to call Close on the connection when you -// are finished with it to release associated resources. -// -// On any error from any method, the connection is closed -// with an appropriate reason. -type Conn struct { - subprotocol string - rwc io.ReadWriteCloser - client bool - copts *compressionOptions - flateThreshold int - br *bufio.Reader - bw *bufio.Writer - - readTimeout chan context.Context - writeTimeout chan context.Context - - // Read state. - readMu *mu - readHeaderBuf [8]byte - readControlBuf [maxControlPayload]byte - msgReader *msgReader - readCloseFrameErr error - - // Write state. - msgWriterState *msgWriterState - writeFrameMu *mu - writeBuf []byte - writeHeaderBuf [8]byte - writeHeader header - - closed chan struct{} - closeMu sync.Mutex - closeErr error - wroteClose bool - - pingCounter int32 - activePingsMu sync.Mutex - activePings map[string]chan<- struct{} -} - -type connConfig struct { - subprotocol string - rwc io.ReadWriteCloser - client bool - copts *compressionOptions - flateThreshold int - - br *bufio.Reader - bw *bufio.Writer -} - -func newConn(cfg connConfig) *Conn { - c := &Conn{ - subprotocol: cfg.subprotocol, - rwc: cfg.rwc, - client: cfg.client, - copts: cfg.copts, - flateThreshold: cfg.flateThreshold, - - br: cfg.br, - bw: cfg.bw, - - readTimeout: make(chan context.Context), - writeTimeout: make(chan context.Context), - - closed: make(chan struct{}), - activePings: make(map[string]chan<- struct{}), - } - - c.readMu = newMu(c) - c.writeFrameMu = newMu(c) - - c.msgReader = newMsgReader(c) - - c.msgWriterState = newMsgWriterState(c) - if c.client { - c.writeBuf = extractBufioWriterBuf(c.bw, c.rwc) - } - - if c.flate() && c.flateThreshold == 0 { - c.flateThreshold = 128 - if !c.msgWriterState.flateContextTakeover() { - c.flateThreshold = 512 - } - } - - runtime.SetFinalizer(c, func(c *Conn) { - c.close(errors.New("connection garbage collected")) - }) - - go c.timeoutLoop() - - return c -} - -// Subprotocol returns the negotiated subprotocol. -// An empty string means the default protocol. -func (c *Conn) Subprotocol() string { - return c.subprotocol -} - -func (c *Conn) close(err error) { - c.closeMu.Lock() - defer c.closeMu.Unlock() - - if c.isClosed() { - return - } - c.setCloseErrLocked(err) - close(c.closed) - runtime.SetFinalizer(c, nil) - - // Have to close after c.closed is closed to ensure any goroutine that wakes up - // from the connection being closed also sees that c.closed is closed and returns - // closeErr. - c.rwc.Close() - - go func() { - c.msgWriterState.close() - - c.msgReader.close() - }() -} - -func (c *Conn) timeoutLoop() { - readCtx := context.Background() - writeCtx := context.Background() - - for { - select { - case <-c.closed: - return - - case writeCtx = <-c.writeTimeout: - case readCtx = <-c.readTimeout: - - case <-readCtx.Done(): - c.setCloseErr(fmt.Errorf("read timed out: %w", readCtx.Err())) - go c.writeError(StatusPolicyViolation, errors.New("timed out")) - case <-writeCtx.Done(): - c.close(fmt.Errorf("write timed out: %w", writeCtx.Err())) - return - } - } -} - -func (c *Conn) flate() bool { - return c.copts != nil -} - -// Ping sends a ping to the peer and waits for a pong. -// Use this to measure latency or ensure the peer is responsive. -// Ping must be called concurrently with Reader as it does -// not read from the connection but instead waits for a Reader call -// to read the pong. -// -// TCP Keepalives should suffice for most use cases. -func (c *Conn) Ping(ctx context.Context) error { - p := atomic.AddInt32(&c.pingCounter, 1) - - err := c.ping(ctx, strconv.Itoa(int(p))) - if err != nil { - return fmt.Errorf("failed to ping: %w", err) - } - return nil -} - -func (c *Conn) ping(ctx context.Context, p string) error { - pong := make(chan struct{}, 1) - - c.activePingsMu.Lock() - c.activePings[p] = pong - c.activePingsMu.Unlock() - - defer func() { - c.activePingsMu.Lock() - delete(c.activePings, p) - c.activePingsMu.Unlock() - }() - - err := c.writeControl(ctx, opPing, []byte(p)) - if err != nil { - return err - } - - select { - case <-c.closed: - return c.closeErr - case <-ctx.Done(): - err := fmt.Errorf("failed to wait for pong: %w", ctx.Err()) - c.close(err) - return err - case <-pong: - return nil - } -} - -type mu struct { - c *Conn - ch chan struct{} -} - -func newMu(c *Conn) *mu { - return &mu{ - c: c, - ch: make(chan struct{}, 1), - } -} - -func (m *mu) forceLock() { - m.ch <- struct{}{} -} - -func (m *mu) lock(ctx context.Context) error { - select { - case <-m.c.closed: - return m.c.closeErr - case <-ctx.Done(): - err := fmt.Errorf("failed to acquire lock: %w", ctx.Err()) - m.c.close(err) - return err - case m.ch <- struct{}{}: - // To make sure the connection is certainly alive. - // As it's possible the send on m.ch was selected - // over the receive on closed. - select { - case <-m.c.closed: - // Make sure to release. - m.unlock() - return m.c.closeErr - default: - } - return nil - } -} - -func (m *mu) unlock() { - select { - case <-m.ch: - default: - } -} diff --git a/conn_test.go b/conn_test.go index c2c41292..97b172dc 100644 --- a/conn_test.go +++ b/conn_test.go @@ -1,13 +1,13 @@ -// +build !js +//go:build !js package websocket_test import ( "bytes" "context" + "errors" "fmt" "io" - "io/ioutil" "net/http" "net/http/httptest" "os" @@ -16,10 +16,6 @@ import ( "testing" "time" - "github.com/gin-gonic/gin" - "github.com/golang/protobuf/ptypes" - "github.com/golang/protobuf/ptypes/duration" - "nhooyr.io/websocket" "nhooyr.io/websocket/internal/errd" "nhooyr.io/websocket/internal/test/assert" @@ -27,7 +23,6 @@ import ( "nhooyr.io/websocket/internal/test/xrand" "nhooyr.io/websocket/internal/xsync" "nhooyr.io/websocket/wsjson" - "nhooyr.io/websocket/wspb" ) func TestConn(t *testing.T) { @@ -37,7 +32,7 @@ func TestConn(t *testing.T) { t.Parallel() compressionMode := func() websocket.CompressionMode { - return websocket.CompressionMode(xrand.Int(int(websocket.CompressionDisabled) + 1)) + return websocket.CompressionMode(xrand.Int(int(websocket.CompressionContextTakeover) + 1)) } for i := 0; i < 5; i++ { @@ -49,7 +44,6 @@ func TestConn(t *testing.T) { CompressionMode: compressionMode(), CompressionThreshold: xrand.Int(9999), }) - defer tt.cleanup() tt.goEchoLoop(c2) @@ -67,8 +61,9 @@ func TestConn(t *testing.T) { }) t.Run("badClose", func(t *testing.T) { - tt, c1, _ := newConnTest(t, nil, nil) - defer tt.cleanup() + tt, c1, c2 := newConnTest(t, nil, nil) + + c2.CloseRead(tt.ctx) err := c1.Close(-1, "") assert.Contains(t, err, "failed to marshal close frame: status code StatusCode(-1) cannot be set") @@ -76,7 +71,6 @@ func TestConn(t *testing.T) { t.Run("ping", func(t *testing.T) { tt, c1, c2 := newConnTest(t, nil, nil) - defer tt.cleanup() c1.CloseRead(tt.ctx) c2.CloseRead(tt.ctx) @@ -92,7 +86,6 @@ func TestConn(t *testing.T) { t.Run("badPing", func(t *testing.T) { tt, c1, c2 := newConnTest(t, nil, nil) - defer tt.cleanup() c2.CloseRead(tt.ctx) @@ -105,7 +98,6 @@ func TestConn(t *testing.T) { t.Run("concurrentWrite", func(t *testing.T) { tt, c1, c2 := newConnTest(t, nil, nil) - defer tt.cleanup() tt.goDiscardLoop(c2) @@ -138,7 +130,6 @@ func TestConn(t *testing.T) { t.Run("concurrentWriteError", func(t *testing.T) { tt, c1, _ := newConnTest(t, nil, nil) - defer tt.cleanup() _, err := c1.Writer(tt.ctx, websocket.MessageText) assert.Success(t, err) @@ -147,12 +138,13 @@ func TestConn(t *testing.T) { defer cancel() err = c1.Write(ctx, websocket.MessageText, []byte("x")) - assert.Equal(t, "write error", context.DeadlineExceeded, err) + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("unexpected error: %#v", err) + } }) t.Run("netConn", func(t *testing.T) { tt, c1, c2 := newConnTest(t, nil, nil) - defer tt.cleanup() n1 := websocket.NetConn(tt.ctx, c1, websocket.MessageBinary) n2 := websocket.NetConn(tt.ctx, c2, websocket.MessageBinary) @@ -163,8 +155,8 @@ func TestConn(t *testing.T) { n1.SetDeadline(time.Time{}) assert.Equal(t, "remote addr", n1.RemoteAddr(), n1.LocalAddr()) - assert.Equal(t, "remote addr string", "websocket/unknown-addr", n1.RemoteAddr().String()) - assert.Equal(t, "remote addr network", "websocket", n1.RemoteAddr().Network()) + assert.Equal(t, "remote addr string", "pipe", n1.RemoteAddr().String()) + assert.Equal(t, "remote addr network", "pipe", n1.RemoteAddr().Network()) errs := xsync.Go(func() error { _, err := n2.Write([]byte("hello")) @@ -174,7 +166,7 @@ func TestConn(t *testing.T) { return n2.Close() }) - b, err := ioutil.ReadAll(n1) + b, err := io.ReadAll(n1) assert.Success(t, err) _, err = n1.Read(nil) @@ -192,21 +184,47 @@ func TestConn(t *testing.T) { t.Run("netConn/BadMsg", func(t *testing.T) { tt, c1, c2 := newConnTest(t, nil, nil) - defer tt.cleanup() n1 := websocket.NetConn(tt.ctx, c1, websocket.MessageBinary) n2 := websocket.NetConn(tt.ctx, c2, websocket.MessageText) + c2.CloseRead(tt.ctx) errs := xsync.Go(func() error { _, err := n2.Write([]byte("hello")) + return err + }) + + _, err := io.ReadAll(n1) + assert.Contains(t, err, `unexpected frame type read (expected MessageBinary): MessageText`) + + select { + case err := <-errs: + assert.Success(t, err) + case <-tt.ctx.Done(): + t.Fatal(tt.ctx.Err()) + } + }) + + t.Run("netConn/readLimit", func(t *testing.T) { + tt, c1, c2 := newConnTest(t, nil, nil) + + n1 := websocket.NetConn(tt.ctx, c1, websocket.MessageBinary) + n2 := websocket.NetConn(tt.ctx, c2, websocket.MessageBinary) + + s := strings.Repeat("papa", 1<<20) + errs := xsync.Go(func() error { + _, err := n2.Write([]byte(s)) if err != nil { return err } - return nil + return n2.Close() }) - _, err := ioutil.ReadAll(n1) - assert.Contains(t, err, `unexpected frame type read (expected MessageBinary): MessageText`) + b, err := io.ReadAll(n1) + assert.Success(t, err) + + _, err = n1.Read(nil) + assert.Equal(t, "read error", err, io.EOF) select { case err := <-errs: @@ -214,11 +232,24 @@ func TestConn(t *testing.T) { case <-tt.ctx.Done(): t.Fatal(tt.ctx.Err()) } + + assert.Equal(t, "read msg", s, string(b)) + }) + + t.Run("netConn/pastDeadline", func(t *testing.T) { + tt, c1, c2 := newConnTest(t, nil, nil) + + n1 := websocket.NetConn(tt.ctx, c1, websocket.MessageBinary) + n2 := websocket.NetConn(tt.ctx, c2, websocket.MessageBinary) + + n1.SetDeadline(time.Now().Add(-time.Minute)) + n2.SetDeadline(time.Now().Add(-time.Minute)) + + // No panic we're good. }) t.Run("wsjson", func(t *testing.T) { tt, c1, c2 := newConnTest(t, nil, nil) - defer tt.cleanup() tt.goEchoLoop(c2) @@ -246,21 +277,67 @@ func TestConn(t *testing.T) { assert.Success(t, err) }) - t.Run("wspb", func(t *testing.T) { - tt, c1, c2 := newConnTest(t, nil, nil) - defer tt.cleanup() + t.Run("HTTPClient.Timeout", func(t *testing.T) { + tt, c1, c2 := newConnTest(t, &websocket.DialOptions{ + HTTPClient: &http.Client{Timeout: time.Second * 5}, + }, nil) tt.goEchoLoop(c2) - exp := ptypes.DurationProto(100) - err := wspb.Write(tt.ctx, c1, exp) - assert.Success(t, err) + c1.SetReadLimit(1 << 30) + + exp := xrand.String(xrand.Int(131072)) - act := &duration.Duration{} - err = wspb.Read(tt.ctx, c1, act) + werr := xsync.Go(func() error { + return wsjson.Write(tt.ctx, c1, exp) + }) + + var act interface{} + err := wsjson.Read(tt.ctx, c1, &act) assert.Success(t, err) assert.Equal(t, "read msg", exp, act) + select { + case err := <-werr: + assert.Success(t, err) + case <-tt.ctx.Done(): + t.Fatal(tt.ctx.Err()) + } + + err = c1.Close(websocket.StatusNormalClosure, "") + assert.Success(t, err) + }) + + t.Run("CloseNow", func(t *testing.T) { + _, c1, c2 := newConnTest(t, nil, nil) + + err1 := c1.CloseNow() + err2 := c2.CloseNow() + assert.Success(t, err1) + assert.Success(t, err2) + err1 = c1.CloseNow() + err2 = c2.CloseNow() + assert.ErrorIs(t, websocket.ErrClosed, err1) + assert.ErrorIs(t, websocket.ErrClosed, err2) + }) + + t.Run("MidReadClose", func(t *testing.T) { + tt, c1, c2 := newConnTest(t, nil, nil) + + tt.goEchoLoop(c2) + + c1.SetReadLimit(131072) + + for i := 0; i < 5; i++ { + err := wstest.Echo(tt.ctx, c1, 131072) + assert.Success(t, err) + } + + err := wsjson.Write(tt.ctx, c1, "four") + assert.Success(t, err) + _, _, err = c1.Reader(tt.ctx) + assert.Success(t, err) + err = c1.Close(websocket.StatusNormalClosure, "") assert.Success(t, err) }) @@ -305,8 +382,6 @@ func assertCloseStatus(exp websocket.StatusCode, err error) error { type connTest struct { t testing.TB ctx context.Context - - doneFuncs []func() } func newConnTest(t testing.TB, dialOpts *websocket.DialOptions, acceptOpts *websocket.AcceptOptions) (tt *connTest, c1, c2 *websocket.Conn) { @@ -317,30 +392,20 @@ func newConnTest(t testing.TB, dialOpts *websocket.DialOptions, acceptOpts *webs ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) tt = &connTest{t: t, ctx: ctx} - tt.appendDone(cancel) + t.Cleanup(cancel) c1, c2 = wstest.Pipe(dialOpts, acceptOpts) if xrand.Bool() { c1, c2 = c2, c1 } - tt.appendDone(func() { - c2.Close(websocket.StatusInternalError, "") - c1.Close(websocket.StatusInternalError, "") + t.Cleanup(func() { + c2.CloseNow() + c1.CloseNow() }) return tt, c1, c2 } -func (tt *connTest) appendDone(f func()) { - tt.doneFuncs = append(tt.doneFuncs, f) -} - -func (tt *connTest) cleanup() { - for i := len(tt.doneFuncs) - 1; i >= 0; i-- { - tt.doneFuncs[i]() - } -} - func (tt *connTest) goEchoLoop(c *websocket.Conn) { ctx, cancel := context.WithCancel(tt.ctx) @@ -348,7 +413,7 @@ func (tt *connTest) goEchoLoop(c *websocket.Conn) { err := wstest.EchoLoop(ctx, c) return assertCloseStatus(websocket.StatusNormalClosure, err) }) - tt.appendDone(func() { + tt.t.Cleanup(func() { cancel() err := <-echoLoopErr if err != nil { @@ -370,7 +435,7 @@ func (tt *connTest) goDiscardLoop(c *websocket.Conn) { } } }) - tt.appendDone(func() { + tt.t.Cleanup(func() { cancel() err := <-discardLoopErr if err != nil { @@ -389,7 +454,7 @@ func BenchmarkConn(b *testing.B) { mode: websocket.CompressionDisabled, }, { - name: "compress", + name: "compressContextTakeover", mode: websocket.CompressionContextTakeover, }, { @@ -404,7 +469,6 @@ func BenchmarkConn(b *testing.B) { }, &websocket.AcceptOptions{ CompressionMode: bc.mode, }) - defer bb.cleanup() bb.goEchoLoop(c2) @@ -438,7 +502,7 @@ func BenchmarkConn(b *testing.B) { typ, r, err := c1.Reader(bb.ctx) if err != nil { - b.Fatal(err) + b.Fatal(i, err) } if websocket.MessageText != typ { assert.Equal(b, "data type", websocket.MessageText, typ) @@ -494,36 +558,55 @@ func echoServer(w http.ResponseWriter, r *http.Request, opts *websocket.AcceptOp return assertCloseStatus(websocket.StatusNormalClosure, err) } -func TestGin(t *testing.T) { - t.Parallel() +func assertEcho(tb testing.TB, ctx context.Context, c *websocket.Conn) { + exp := xrand.String(xrand.Int(131072)) - gin.SetMode(gin.ReleaseMode) - r := gin.New() - r.GET("/", func(ginCtx *gin.Context) { - err := echoServer(ginCtx.Writer, ginCtx.Request, nil) - if err != nil { - t.Error(err) - } + werr := xsync.Go(func() error { + return wsjson.Write(ctx, c, exp) }) - s := httptest.NewServer(r) - defer s.Close() - - ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) - defer cancel() - - c, _, err := websocket.Dial(ctx, s.URL, nil) - assert.Success(t, err) - defer c.Close(websocket.StatusInternalError, "") + var act interface{} + c.SetReadLimit(1 << 30) + err := wsjson.Read(ctx, c, &act) + assert.Success(tb, err) + assert.Equal(tb, "read msg", exp, act) + + select { + case err := <-werr: + assert.Success(tb, err) + case <-ctx.Done(): + tb.Fatal(ctx.Err()) + } +} - err = wsjson.Write(ctx, c, "hello") - assert.Success(t, err) +func assertClose(tb testing.TB, c *websocket.Conn) { + tb.Helper() + err := c.Close(websocket.StatusNormalClosure, "") + assert.Success(tb, err) +} - var v interface{} - err = wsjson.Read(ctx, c, &v) - assert.Success(t, err) - assert.Equal(t, "read msg", "hello", v) +func TestConcurrentClosePing(t *testing.T) { + t.Parallel() + for i := 0; i < 64; i++ { + func() { + c1, c2 := wstest.Pipe(nil, nil) + defer c1.CloseNow() + defer c2.CloseNow() + c1.CloseRead(context.Background()) + c2.CloseRead(context.Background()) + errc := xsync.Go(func() error { + for range time.Tick(time.Millisecond) { + err := c1.Ping(context.Background()) + if err != nil { + return err + } + } + panic("unreachable") + }) - err = c.Close(websocket.StatusNormalClosure, "") - assert.Success(t, err) + time.Sleep(10 * time.Millisecond) + assert.Success(t, c1.Close(websocket.StatusNormalClosure, "")) + <-errc + }() + } } diff --git a/dial.go b/dial.go index 7a7787ff..e4c4daa1 100644 --- a/dial.go +++ b/dial.go @@ -1,3 +1,4 @@ +//go:build !js // +build !js package websocket @@ -10,7 +11,6 @@ import ( "encoding/base64" "fmt" "io" - "io/ioutil" "net/http" "net/url" "strings" @@ -30,11 +30,15 @@ type DialOptions struct { // HTTPHeader specifies the HTTP headers included in the handshake request. HTTPHeader http.Header + // Host optionally overrides the Host HTTP header to send. If empty, the value + // of URL.Host will be used. + Host string + // Subprotocols lists the WebSocket subprotocols to negotiate with the server. Subprotocols []string // CompressionMode controls the compression mode. - // Defaults to CompressionNoContextTakeover. + // Defaults to CompressionDisabled. // // See docs on CompressionMode for details. CompressionMode CompressionMode @@ -46,6 +50,45 @@ type DialOptions struct { CompressionThreshold int } +func (opts *DialOptions) cloneWithDefaults(ctx context.Context) (context.Context, context.CancelFunc, *DialOptions) { + var cancel context.CancelFunc + + var o DialOptions + if opts != nil { + o = *opts + } + if o.HTTPClient == nil { + o.HTTPClient = http.DefaultClient + } + if o.HTTPClient.Timeout > 0 { + ctx, cancel = context.WithTimeout(ctx, o.HTTPClient.Timeout) + + newClient := *o.HTTPClient + newClient.Timeout = 0 + o.HTTPClient = &newClient + } + if o.HTTPHeader == nil { + o.HTTPHeader = http.Header{} + } + newClient := *o.HTTPClient + oldCheckRedirect := o.HTTPClient.CheckRedirect + newClient.CheckRedirect = func(req *http.Request, via []*http.Request) error { + switch req.URL.Scheme { + case "ws": + req.URL.Scheme = "http" + case "wss": + req.URL.Scheme = "https" + } + if oldCheckRedirect != nil { + return oldCheckRedirect(req, via) + } + return nil + } + o.HTTPClient = &newClient + + return ctx, cancel, &o +} + // Dial performs a WebSocket handshake on url. // // The response is the WebSocket handshake response from the server. @@ -66,26 +109,10 @@ func Dial(ctx context.Context, u string, opts *DialOptions) (*Conn, *http.Respon func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (_ *Conn, _ *http.Response, err error) { defer errd.Wrap(&err, "failed to WebSocket dial") - if opts == nil { - opts = &DialOptions{} - } - - opts = &*opts - if opts.HTTPClient == nil { - opts.HTTPClient = http.DefaultClient - } else if opts.HTTPClient.Timeout > 0 { - var cancel context.CancelFunc - - ctx, cancel = context.WithTimeout(ctx, opts.HTTPClient.Timeout) + var cancel context.CancelFunc + ctx, cancel, opts = opts.cloneWithDefaults(ctx) + if cancel != nil { defer cancel() - - newClient := *opts.HTTPClient - newClient.Timeout = 0 - opts.HTTPClient = &newClient - } - - if opts.HTTPHeader == nil { - opts.HTTPHeader = http.Header{} } secWebSocketKey, err := secWebSocketKey(rand) @@ -114,9 +141,9 @@ func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) ( }) defer timer.Stop() - b, _ := ioutil.ReadAll(r) + b, _ := io.ReadAll(r) respBody.Close() - resp.Body = ioutil.NopCloser(bytes.NewReader(b)) + resp.Body = io.NopCloser(bytes.NewReader(b)) } }() @@ -157,7 +184,13 @@ func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, copts return nil, fmt.Errorf("unexpected url scheme: %q", u.Scheme) } - req, _ := http.NewRequestWithContext(ctx, "GET", u.String(), nil) + req, err := http.NewRequestWithContext(ctx, "GET", u.String(), nil) + if err != nil { + return nil, fmt.Errorf("failed to create new http request: %w", err) + } + if len(opts.Host) > 0 { + req.Host = opts.Host + } req.Header = opts.HTTPHeader.Clone() req.Header.Set("Connection", "Upgrade") req.Header.Set("Upgrade", "websocket") @@ -167,7 +200,7 @@ func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, copts req.Header.Set("Sec-WebSocket-Protocol", strings.Join(opts.Subprotocols, ",")) } if copts != nil { - copts.setHeader(req.Header) + req.Header.Set("Sec-WebSocket-Extensions", copts.String()) } resp, err := opts.HTTPClient.Do(req) @@ -243,7 +276,8 @@ func verifyServerExtensions(copts *compressionOptions, h http.Header) (*compress return nil, fmt.Errorf("WebSocket protcol violation: unsupported extensions from server: %+v", exts[1:]) } - copts = &*copts + _copts := *copts + copts = &_copts for _, p := range ext.params { switch p { @@ -254,6 +288,10 @@ func verifyServerExtensions(copts *compressionOptions, h http.Header) (*compress copts.serverNoContextTakeover = true continue } + if strings.HasPrefix(p, "server_max_window_bits=") { + // We can't adjust the deflate window, but decoding with a larger window is acceptable. + continue + } return nil, fmt.Errorf("unsupported permessage-deflate parameter: %q", p) } diff --git a/dial_test.go b/dial_test.go index 28c255c6..237a2874 100644 --- a/dial_test.go +++ b/dial_test.go @@ -1,19 +1,24 @@ +//go:build !js // +build !js -package websocket +package websocket_test import ( + "bytes" "context" "crypto/rand" "io" - "io/ioutil" "net/http" "net/http/httptest" + "net/url" "strings" "testing" "time" + "nhooyr.io/websocket" "nhooyr.io/websocket/internal/test/assert" + "nhooyr.io/websocket/internal/util" + "nhooyr.io/websocket/internal/xsync" ) func TestBadDials(t *testing.T) { @@ -23,10 +28,11 @@ func TestBadDials(t *testing.T) { t.Parallel() testCases := []struct { - name string - url string - opts *DialOptions - rand readerFunc + name string + url string + opts *websocket.DialOptions + rand util.ReaderFunc + nilCtx bool }{ { name: "badURL", @@ -46,6 +52,11 @@ func TestBadDials(t *testing.T) { return 0, io.EOF }, }, + { + name: "nilContext", + url: "http://localhost", + nilCtx: true, + }, } for _, tc := range testCases { @@ -53,14 +64,18 @@ func TestBadDials(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) - defer cancel() + var ctx context.Context + var cancel func() + if !tc.nilCtx { + ctx, cancel = context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + } if tc.rand == nil { tc.rand = rand.Reader.Read } - _, _, err := dial(ctx, tc.url, tc.opts, tc.rand) + _, _, err := websocket.ExportedDial(ctx, tc.url, tc.opts, tc.rand) assert.Error(t, err) }) } @@ -72,10 +87,10 @@ func TestBadDials(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() - _, _, err := Dial(ctx, "ws://example.com", &DialOptions{ + _, _, err := websocket.Dial(ctx, "ws://example.com", &websocket.DialOptions{ HTTPClient: mockHTTPClient(func(*http.Request) (*http.Response, error) { return &http.Response{ - Body: ioutil.NopCloser(strings.NewReader("hi")), + Body: io.NopCloser(strings.NewReader("hi")), }, nil }), }) @@ -92,22 +107,82 @@ func TestBadDials(t *testing.T) { h := http.Header{} h.Set("Connection", "Upgrade") h.Set("Upgrade", "websocket") - h.Set("Sec-WebSocket-Accept", secWebSocketAccept(r.Header.Get("Sec-WebSocket-Key"))) + h.Set("Sec-WebSocket-Accept", websocket.SecWebSocketAccept(r.Header.Get("Sec-WebSocket-Key"))) return &http.Response{ StatusCode: http.StatusSwitchingProtocols, Header: h, - Body: ioutil.NopCloser(strings.NewReader("hi")), + Body: io.NopCloser(strings.NewReader("hi")), }, nil } - _, _, err := Dial(ctx, "ws://example.com", &DialOptions{ + _, _, err := websocket.Dial(ctx, "ws://example.com", &websocket.DialOptions{ HTTPClient: mockHTTPClient(rt), }) assert.Contains(t, err, "response body is not a io.ReadWriteCloser") }) } +func Test_verifyHostOverride(t *testing.T) { + testCases := []struct { + name string + host string + exp string + }{ + { + name: "noOverride", + host: "", + exp: "example.com", + }, + { + name: "hostOverride", + host: "example.net", + exp: "example.net", + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + rt := func(r *http.Request) (*http.Response, error) { + assert.Equal(t, "Host", tc.exp, r.Host) + + h := http.Header{} + h.Set("Connection", "Upgrade") + h.Set("Upgrade", "websocket") + h.Set("Sec-WebSocket-Accept", websocket.SecWebSocketAccept(r.Header.Get("Sec-WebSocket-Key"))) + + return &http.Response{ + StatusCode: http.StatusSwitchingProtocols, + Header: h, + Body: mockBody{bytes.NewBufferString("hi")}, + }, nil + } + + c, _, err := websocket.Dial(ctx, "ws://example.com", &websocket.DialOptions{ + HTTPClient: mockHTTPClient(rt), + Host: tc.host, + }) + assert.Success(t, err) + c.CloseNow() + }) + } + +} + +type mockBody struct { + *bytes.Buffer +} + +func (mb mockBody) Close() error { + return nil +} + func Test_verifyServerHandshake(t *testing.T) { t.Parallel() @@ -201,18 +276,18 @@ func Test_verifyServerHandshake(t *testing.T) { resp := w.Result() r := httptest.NewRequest("GET", "/", nil) - key, err := secWebSocketKey(rand.Reader) + key, err := websocket.SecWebSocketKey(rand.Reader) assert.Success(t, err) r.Header.Set("Sec-WebSocket-Key", key) if resp.Header.Get("Sec-WebSocket-Accept") == "" { - resp.Header.Set("Sec-WebSocket-Accept", secWebSocketAccept(key)) + resp.Header.Set("Sec-WebSocket-Accept", websocket.SecWebSocketAccept(key)) } - opts := &DialOptions{ + opts := &websocket.DialOptions{ Subprotocols: strings.Split(r.Header.Get("Sec-WebSocket-Protocol"), ","), } - _, err = verifyServerResponse(opts, opts.CompressionMode.opts(), key, resp) + _, err = websocket.VerifyServerResponse(opts, websocket.CompressionModeOpts(opts.CompressionMode), key, resp) if tc.success { assert.Success(t, err) } else { @@ -233,3 +308,113 @@ type roundTripperFunc func(*http.Request) (*http.Response, error) func (f roundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) { return f(r) } + +func TestDialRedirect(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + _, _, err := websocket.Dial(ctx, "ws://example.com", &websocket.DialOptions{ + HTTPClient: mockHTTPClient(func(r *http.Request) (*http.Response, error) { + resp := &http.Response{ + Header: http.Header{}, + } + if r.URL.Scheme != "https" { + resp.Header.Set("Location", "wss://example.com") + resp.StatusCode = http.StatusFound + return resp, nil + } + resp.Header.Set("Connection", "Upgrade") + resp.Header.Set("Upgrade", "meow") + resp.StatusCode = http.StatusSwitchingProtocols + return resp, nil + }), + }) + assert.Contains(t, err, "failed to WebSocket dial: WebSocket protocol violation: Upgrade header \"meow\" does not contain websocket") +} + +type forwardProxy struct { + hc *http.Client +} + +func newForwardProxy() *forwardProxy { + return &forwardProxy{ + hc: &http.Client{}, + } +} + +func (fc *forwardProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx, cancel := context.WithTimeout(r.Context(), time.Second*10) + defer cancel() + + r = r.WithContext(ctx) + r.RequestURI = "" + resp, err := fc.hc.Do(r) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + defer resp.Body.Close() + + for k, v := range resp.Header { + w.Header()[k] = v + } + w.Header().Set("PROXIED", "true") + w.WriteHeader(resp.StatusCode) + if resprw, ok := resp.Body.(io.ReadWriter); ok { + c, brw, err := w.(http.Hijacker).Hijack() + if err != nil { + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + return + } + brw.Flush() + + errc1 := xsync.Go(func() error { + _, err := io.Copy(c, resprw) + return err + }) + errc2 := xsync.Go(func() error { + _, err := io.Copy(resprw, c) + return err + }) + select { + case <-errc1: + case <-errc2: + case <-r.Context().Done(): + } + } else { + io.Copy(w, resp.Body) + } +} + +func TestDialViaProxy(t *testing.T) { + t.Parallel() + + ps := httptest.NewServer(newForwardProxy()) + defer ps.Close() + + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + err := echoServer(w, r, nil) + assert.Success(t, err) + })) + defer s.Close() + + psu, err := url.Parse(ps.URL) + assert.Success(t, err) + proxyTransport := http.DefaultTransport.(*http.Transport).Clone() + proxyTransport.Proxy = http.ProxyURL(psu) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + c, resp, err := websocket.Dial(ctx, s.URL, &websocket.DialOptions{ + HTTPClient: &http.Client{ + Transport: proxyTransport, + }, + }) + assert.Success(t, err) + assert.Equal(t, "", "true", resp.Header.Get("PROXIED")) + + assertEcho(t, ctx, c) + assertClose(t, c) +} diff --git a/doc.go b/doc.go index efa920e3..2ab648a6 100644 --- a/doc.go +++ b/doc.go @@ -1,3 +1,4 @@ +//go:build !js // +build !js // Package websocket implements the RFC 6455 WebSocket protocol. @@ -12,11 +13,11 @@ // // The examples are the best way to understand how to correctly use the library. // -// The wsjson and wspb subpackages contain helpers for JSON and protobuf messages. +// The wsjson subpackage contain helpers for JSON and protobuf messages. // // More documentation at https://nhooyr.io/websocket. // -// Wasm +// # Wasm // // The client side supports compiling to Wasm. // It wraps the WebSocket browser API. @@ -25,8 +26,9 @@ // // Some important caveats to be aware of: // -// - Accept always errors out -// - Conn.Ping is no-op -// - HTTPClient, HTTPHeader and CompressionMode in DialOptions are no-op -// - *http.Response from Dial is &http.Response{} with a 101 status code on success +// - Accept always errors out +// - Conn.Ping is no-op +// - Conn.CloseNow is Close(StatusGoingAway, "") +// - HTTPClient, HTTPHeader and CompressionMode in DialOptions are no-op +// - *http.Response from Dial is &http.Response{} with a 101 status code on success package websocket // import "nhooyr.io/websocket" diff --git a/example_test.go b/example_test.go index 632c4d6e..590c0411 100644 --- a/example_test.go +++ b/example_test.go @@ -20,7 +20,7 @@ func ExampleAccept() { log.Println(err) return } - defer c.Close(websocket.StatusInternalError, "the sky is falling") + defer c.CloseNow() ctx, cancel := context.WithTimeout(r.Context(), time.Second*10) defer cancel() @@ -50,7 +50,7 @@ func ExampleDial() { if err != nil { log.Fatal(err) } - defer c.Close(websocket.StatusInternalError, "the sky is falling") + defer c.CloseNow() err = wsjson.Write(ctx, c, "hi") if err != nil { @@ -71,7 +71,7 @@ func ExampleCloseStatus() { if err != nil { log.Fatal(err) } - defer c.Close(websocket.StatusInternalError, "the sky is falling") + defer c.CloseNow() _, _, err = c.Reader(ctx) if websocket.CloseStatus(err) != websocket.StatusNormalClosure { @@ -88,7 +88,7 @@ func Example_writeOnly() { log.Println(err) return } - defer c.Close(websocket.StatusInternalError, "the sky is falling") + defer c.CloseNow() ctx, cancel := context.WithTimeout(r.Context(), time.Minute*10) defer cancel() @@ -135,64 +135,37 @@ func Example_crossOrigin() { log.Fatal(err) } -// This example demonstrates how to create a WebSocket server -// that gracefully exits when sent a signal. -// -// It starts a WebSocket server that keeps every connection open -// for 10 seconds. -// If you CTRL+C while a connection is open, it will wait at most 30s -// for all connections to terminate before shutting down. -// func ExampleGrace() { -// fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { -// c, err := websocket.Accept(w, r, nil) -// if err != nil { -// log.Println(err) -// return -// } -// defer c.Close(websocket.StatusInternalError, "the sky is falling") -// -// ctx := c.CloseRead(r.Context()) -// select { -// case <-ctx.Done(): -// case <-time.After(time.Second * 10): -// } -// -// c.Close(websocket.StatusNormalClosure, "") -// }) -// -// var g websocket.Grace -// s := &http.Server{ -// Handler: g.Handler(fn), -// ReadTimeout: time.Second * 15, -// WriteTimeout: time.Second * 15, -// } -// -// errc := make(chan error, 1) -// go func() { -// errc <- s.ListenAndServe() -// }() -// -// sigs := make(chan os.Signal, 1) -// signal.Notify(sigs, os.Interrupt) -// select { -// case err := <-errc: -// log.Printf("failed to listen and serve: %v", err) -// case sig := <-sigs: -// log.Printf("terminating: %v", sig) -// } -// -// ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) -// defer cancel() -// s.Shutdown(ctx) -// g.Shutdown(ctx) -// } +func ExampleConn_Ping() { + // Dials a server and pings it 5 times. + + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + + c, _, err := websocket.Dial(ctx, "ws://localhost:8080", nil) + if err != nil { + log.Fatal(err) + } + defer c.CloseNow() + + // Required to read the Pongs from the server. + ctx = c.CloseRead(ctx) + + for i := 0; i < 5; i++ { + err = c.Ping(ctx) + if err != nil { + log.Fatal(err) + } + } + + c.Close(websocket.StatusNormalClosure, "") +} // This example demonstrates full stack chat with an automated test. func Example_fullStackChat() { - // https://github.com/nhooyr/websocket/tree/master/examples/chat + // https://github.com/nhooyr/websocket/tree/master/internal/examples/chat } // This example demonstrates a echo server. func Example_echo() { - // https://github.com/nhooyr/websocket/tree/master/examples/echo + // https://github.com/nhooyr/websocket/tree/master/internal/examples/echo } diff --git a/export_test.go b/export_test.go index 88b82c9f..a644d8f0 100644 --- a/export_test.go +++ b/export_test.go @@ -1,10 +1,17 @@ +//go:build !js // +build !js package websocket +import ( + "net" + + "nhooyr.io/websocket/internal/util" +) + func (c *Conn) RecordBytesWritten() *int { var bytesWritten int - c.bw.Reset(writerFunc(func(p []byte) (int, error) { + c.bw.Reset(util.WriterFunc(func(p []byte) (int, error) { bytesWritten += len(p) return c.rwc.Write(p) })) @@ -13,10 +20,19 @@ func (c *Conn) RecordBytesWritten() *int { func (c *Conn) RecordBytesRead() *int { var bytesRead int - c.br.Reset(readerFunc(func(p []byte) (int, error) { + c.br.Reset(util.ReaderFunc(func(p []byte) (int, error) { n, err := c.rwc.Read(p) bytesRead += n return n, err })) return &bytesRead } + +var ErrClosed = net.ErrClosed + +var ExportedDial = dial +var SecWebSocketAccept = secWebSocketAccept +var SecWebSocketKey = secWebSocketKey +var VerifyServerResponse = verifyServerResponse + +var CompressionModeOpts = CompressionMode.opts diff --git a/frame.go b/frame.go index 2a036f94..351632fd 100644 --- a/frame.go +++ b/frame.go @@ -1,3 +1,5 @@ +//go:build !js + package websocket import ( diff --git a/frame_test.go b/frame_test.go index 76826248..e697e198 100644 --- a/frame_test.go +++ b/frame_test.go @@ -1,3 +1,4 @@ +//go:build !js // +build !js package websocket @@ -11,10 +12,6 @@ import ( "strconv" "testing" "time" - _ "unsafe" - - "github.com/gobwas/ws" - _ "github.com/gorilla/websocket" "nhooyr.io/websocket/internal/test/assert" ) @@ -54,7 +51,7 @@ func TestHeader(t *testing.T) { r := rand.New(rand.NewSource(time.Now().UnixNano())) randBool := func() bool { - return r.Intn(1) == 0 + return r.Intn(2) == 0 } for i := 0; i < 10000; i++ { @@ -66,9 +63,11 @@ func TestHeader(t *testing.T) { opcode: opcode(r.Intn(16)), masked: randBool(), - maskKey: r.Uint32(), payloadLength: r.Int63(), } + if h.masked { + h.maskKey = r.Uint32() + } testHeader(t, h) } @@ -106,87 +105,3 @@ func Test_mask(t *testing.T) { expKey32 := bits.RotateLeft32(key32, -8) assert.Equal(t, "key32", expKey32, gotKey32) } - -func basicMask(maskKey [4]byte, pos int, b []byte) int { - for i := range b { - b[i] ^= maskKey[pos&3] - pos++ - } - return pos & 3 -} - -//go:linkname gorillaMaskBytes github.com/gorilla/websocket.maskBytes -func gorillaMaskBytes(key [4]byte, pos int, b []byte) int - -func Benchmark_mask(b *testing.B) { - sizes := []int{ - 2, - 3, - 4, - 8, - 16, - 32, - 128, - 512, - 4096, - 16384, - } - - fns := []struct { - name string - fn func(b *testing.B, key [4]byte, p []byte) - }{ - { - name: "basic", - fn: func(b *testing.B, key [4]byte, p []byte) { - for i := 0; i < b.N; i++ { - basicMask(key, 0, p) - } - }, - }, - - { - name: "nhooyr", - fn: func(b *testing.B, key [4]byte, p []byte) { - key32 := binary.LittleEndian.Uint32(key[:]) - b.ResetTimer() - - for i := 0; i < b.N; i++ { - mask(key32, p) - } - }, - }, - { - name: "gorilla", - fn: func(b *testing.B, key [4]byte, p []byte) { - for i := 0; i < b.N; i++ { - gorillaMaskBytes(key, 0, p) - } - }, - }, - { - name: "gobwas", - fn: func(b *testing.B, key [4]byte, p []byte) { - for i := 0; i < b.N; i++ { - ws.Cipher(p, key, 0) - } - }, - }, - } - - key := [4]byte{1, 2, 3, 4} - - for _, size := range sizes { - p := make([]byte, size) - - b.Run(strconv.Itoa(size), func(b *testing.B) { - for _, fn := range fns { - b.Run(fn.name, func(b *testing.B) { - b.SetBytes(int64(size)) - - fn.fn(b, key, p) - }) - } - }) - } -} diff --git a/go.mod b/go.mod index c5f1a20f..715a9f7a 100644 --- a/go.mod +++ b/go.mod @@ -1,15 +1,3 @@ module nhooyr.io/websocket -go 1.13 - -require ( - github.com/gin-gonic/gin v1.6.3 - github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee // indirect - github.com/gobwas/pool v0.2.0 // indirect - github.com/gobwas/ws v1.0.2 - github.com/golang/protobuf v1.3.5 - github.com/google/go-cmp v0.4.0 - github.com/gorilla/websocket v1.4.1 - github.com/klauspost/compress v1.10.3 - golang.org/x/time v0.0.0-20191024005414-555d28b269f0 -) +go 1.19 diff --git a/go.sum b/go.sum deleted file mode 100644 index 155c3013..00000000 --- a/go.sum +++ /dev/null @@ -1,64 +0,0 @@ -github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= -github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= -github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= -github.com/gin-gonic/gin v1.6.3 h1:ahKqKTFpO5KTPHxWZjEdPScmYaGtLo8Y4DMHoEsnp14= -github.com/gin-gonic/gin v1.6.3/go.mod h1:75u5sXoLsGZoRN5Sgbi1eraJ4GU3++wFwWzhwvtwp4M= -github.com/go-playground/assert/v2 v2.0.1 h1:MsBgLAaY856+nPRTKrp3/OZK38U/wa0CcBYNjji3q3A= -github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= -github.com/go-playground/locales v0.13.0 h1:HyWk6mgj5qFqCT5fjGBuRArbVDfE4hi8+e8ceBS/t7Q= -github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8= -github.com/go-playground/universal-translator v0.17.0 h1:icxd5fm+REJzpZx7ZfpaD876Lmtgy7VtROAbHHXk8no= -github.com/go-playground/universal-translator v0.17.0/go.mod h1:UkSxE5sNxxRwHyU+Scu5vgOQjsIJAF8j9muTVoKLVtA= -github.com/go-playground/validator/v10 v10.2.0 h1:KgJ0snyC2R9VXYN2rneOtQcw5aHQB1Vv0sFl1UcHBOY= -github.com/go-playground/validator/v10 v10.2.0/go.mod h1:uOYAAleCW8F/7oMFd6aG0GOhaH6EGOAJShg8Id5JGkI= -github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee h1:s+21KNqlpePfkah2I+gwHF8xmJWRjooY+5248k6m4A0= -github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee/go.mod h1:L0fX3K22YWvt/FAX9NnzrNzcI4wNYi9Yku4O0LKYflo= -github.com/gobwas/pool v0.2.0 h1:QEmUOlnSjWtnpRGHF3SauEiOsy82Cup83Vf2LcMlnc8= -github.com/gobwas/pool v0.2.0/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw= -github.com/gobwas/ws v1.0.2 h1:CoAavW/wd/kulfZmSIBt6p24n4j7tHgNVCjsfHVNUbo= -github.com/gobwas/ws v1.0.2/go.mod h1:szmBTxLgaFppYjEmNtny/v3w89xOydFnnZMcgRRu/EM= -github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= -github.com/golang/protobuf v1.3.5 h1:F768QJ1E9tib+q5Sc8MkdJi1RxLTbRcTf8LJV56aRls= -github.com/golang/protobuf v1.3.5/go.mod h1:6O5/vntMXwX2lRkT1hjjk0nAC1IDOTvTlVgjlRvqsdk= -github.com/google/go-cmp v0.4.0 h1:xsAVV57WRhGj6kEIi8ReJzQlHHqcBYCElAvkovg3B/4= -github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= -github.com/gorilla/websocket v1.4.1 h1:q7AeDBpnBk8AogcD4DSag/Ukw/KV+YhzLj2bP5HvKCM= -github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= -github.com/json-iterator/go v1.1.9 h1:9yzud/Ht36ygwatGx56VwCZtlI/2AD15T1X2sjSuGns= -github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= -github.com/klauspost/compress v1.10.3 h1:OP96hzwJVBIHYU52pVTI6CczrxPvrGfgqF9N5eTO0Q8= -github.com/klauspost/compress v1.10.3/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs= -github.com/leodido/go-urn v1.2.0 h1:hpXL4XnriNwQ/ABnpepYM/1vCLWNDfUNts8dX3xTG6Y= -github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII= -github.com/mattn/go-isatty v0.0.12 h1:wuysRhFDzyxgEmMf5xjvJ2M9dZoWAXNNr5LSBS7uHXY= -github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= -github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 h1:ZqeYNhU3OHLH3mGKHDcjJRFFRrJa6eAM5H+CtDdOsPc= -github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= -github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742 h1:Esafd1046DLDQ0W1YjYsBW+p8U2u7vzgW2SQVmlNazg= -github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= -github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= -github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= -github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= -github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= -github.com/ugorji/go v1.1.7 h1:/68gy2h+1mWMrwZFeD1kQialdSzAb432dtpeJ42ovdo= -github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw= -github.com/ugorji/go/codec v1.1.7 h1:2SvQaVZ1ouYrrKKwoSk2pzd4A9evlKJb9oTL+OaLUSs= -github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLYF3GoBXY= -golang.org/x/sys v0.0.0-20200116001909-b77594299b42 h1:vEOn+mP2zCOVzKckCZy6YsCtDblrpj/w7B9nxGNELpg= -golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= -golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs= -golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10= -gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/examples/README.md b/internal/examples/README.md similarity index 100% rename from examples/README.md rename to internal/examples/README.md diff --git a/examples/chat/README.md b/internal/examples/chat/README.md similarity index 100% rename from examples/chat/README.md rename to internal/examples/chat/README.md diff --git a/examples/chat/chat.go b/internal/examples/chat/chat.go similarity index 88% rename from examples/chat/chat.go rename to internal/examples/chat/chat.go index 532e50f5..8b1e30c1 100644 --- a/examples/chat/chat.go +++ b/internal/examples/chat/chat.go @@ -3,8 +3,9 @@ package main import ( "context" "errors" - "io/ioutil" + "io" "log" + "net" "net/http" "sync" "time" @@ -69,14 +70,7 @@ func (cs *chatServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { // subscribeHandler accepts the WebSocket connection and then subscribes // it to all future messages. func (cs *chatServer) subscribeHandler(w http.ResponseWriter, r *http.Request) { - c, err := websocket.Accept(w, r, nil) - if err != nil { - cs.logf("%v", err) - return - } - defer c.Close(websocket.StatusInternalError, "") - - err = cs.subscribe(r.Context(), c) + err := cs.subscribe(r.Context(), w, r) if errors.Is(err, context.Canceled) { return } @@ -98,7 +92,7 @@ func (cs *chatServer) publishHandler(w http.ResponseWriter, r *http.Request) { return } body := http.MaxBytesReader(w, r.Body, 8192) - msg, err := ioutil.ReadAll(body) + msg, err := io.ReadAll(body) if err != nil { http.Error(w, http.StatusText(http.StatusRequestEntityTooLarge), http.StatusRequestEntityTooLarge) return @@ -117,18 +111,39 @@ func (cs *chatServer) publishHandler(w http.ResponseWriter, r *http.Request) { // // It uses CloseRead to keep reading from the connection to process control // messages and cancel the context if the connection drops. -func (cs *chatServer) subscribe(ctx context.Context, c *websocket.Conn) error { - ctx = c.CloseRead(ctx) - +func (cs *chatServer) subscribe(ctx context.Context, w http.ResponseWriter, r *http.Request) error { + var mu sync.Mutex + var c *websocket.Conn + var closed bool s := &subscriber{ msgs: make(chan []byte, cs.subscriberMessageBuffer), closeSlow: func() { - c.Close(websocket.StatusPolicyViolation, "connection too slow to keep up with messages") + mu.Lock() + defer mu.Unlock() + closed = true + if c != nil { + c.Close(websocket.StatusPolicyViolation, "connection too slow to keep up with messages") + } }, } cs.addSubscriber(s) defer cs.deleteSubscriber(s) + c2, err := websocket.Accept(w, r, nil) + if err != nil { + return err + } + mu.Lock() + if closed { + mu.Unlock() + return net.ErrClosed + } + c = c2 + mu.Unlock() + defer c.CloseNow() + + ctx = c.CloseRead(ctx) + for { select { case msg := <-s.msgs: diff --git a/examples/chat/chat_test.go b/internal/examples/chat/chat_test.go similarity index 100% rename from examples/chat/chat_test.go rename to internal/examples/chat/chat_test.go diff --git a/examples/chat/index.css b/internal/examples/chat/index.css similarity index 86% rename from examples/chat/index.css rename to internal/examples/chat/index.css index 73a8e0f3..ce27c378 100644 --- a/examples/chat/index.css +++ b/internal/examples/chat/index.css @@ -54,7 +54,7 @@ body { margin: 0 0 0 10px; } -#publish-form input[type="text"] { +#publish-form input[type='text'] { flex-grow: 1; -moz-appearance: none; @@ -64,7 +64,7 @@ body { border: 1px solid #ccc; } -#publish-form input[type="submit"] { +#publish-form input[type='submit'] { color: white; background-color: black; border-radius: 5px; @@ -72,10 +72,10 @@ body { border: none; } -#publish-form input[type="submit"]:hover { +#publish-form input[type='submit']:hover { background-color: red; } -#publish-form input[type="submit"]:active { +#publish-form input[type='submit']:active { background-color: red; } diff --git a/examples/chat/index.html b/internal/examples/chat/index.html similarity index 98% rename from examples/chat/index.html rename to internal/examples/chat/index.html index 76ae8370..64edd286 100644 --- a/examples/chat/index.html +++ b/internal/examples/chat/index.html @@ -1,4 +1,4 @@ - + diff --git a/examples/chat/index.js b/internal/examples/chat/index.js similarity index 65% rename from examples/chat/index.js rename to internal/examples/chat/index.js index 5868e7ca..2efca013 100644 --- a/examples/chat/index.js +++ b/internal/examples/chat/index.js @@ -6,21 +6,21 @@ function dial() { const conn = new WebSocket(`ws://${location.host}/subscribe`) - conn.addEventListener("close", ev => { + conn.addEventListener('close', ev => { appendLog(`WebSocket Disconnected code: ${ev.code}, reason: ${ev.reason}`, true) if (ev.code !== 1001) { - appendLog("Reconnecting in 1s", true) + appendLog('Reconnecting in 1s', true) setTimeout(dial, 1000) } }) - conn.addEventListener("open", ev => { - console.info("websocket connected") + conn.addEventListener('open', ev => { + console.info('websocket connected') }) // This is where we handle messages received. - conn.addEventListener("message", ev => { - if (typeof ev.data !== "string") { - console.error("unexpected message type", typeof ev.data) + conn.addEventListener('message', ev => { + if (typeof ev.data !== 'string') { + console.error('unexpected message type', typeof ev.data) return } const p = appendLog(ev.data) @@ -32,38 +32,38 @@ } dial() - const messageLog = document.getElementById("message-log") - const publishForm = document.getElementById("publish-form") - const messageInput = document.getElementById("message-input") + const messageLog = document.getElementById('message-log') + const publishForm = document.getElementById('publish-form') + const messageInput = document.getElementById('message-input') // appendLog appends the passed text to messageLog. function appendLog(text, error) { - const p = document.createElement("p") + const p = document.createElement('p') // Adding a timestamp to each message makes the log easier to read. p.innerText = `${new Date().toLocaleTimeString()}: ${text}` if (error) { - p.style.color = "red" - p.style.fontStyle = "bold" + p.style.color = 'red' + p.style.fontStyle = 'bold' } messageLog.append(p) return p } - appendLog("Submit a message to get started!") + appendLog('Submit a message to get started!') // onsubmit publishes the message from the user when the form is submitted. publishForm.onsubmit = async ev => { ev.preventDefault() const msg = messageInput.value - if (msg === "") { + if (msg === '') { return } - messageInput.value = "" + messageInput.value = '' expectingMessage = true try { - const resp = await fetch("/publish", { - method: "POST", + const resp = await fetch('/publish', { + method: 'POST', body: msg, }) if (resp.status !== 202) { diff --git a/examples/chat/main.go b/internal/examples/chat/main.go similarity index 100% rename from examples/chat/main.go rename to internal/examples/chat/main.go diff --git a/examples/echo/README.md b/internal/examples/echo/README.md similarity index 100% rename from examples/echo/README.md rename to internal/examples/echo/README.md diff --git a/examples/echo/main.go b/internal/examples/echo/main.go similarity index 100% rename from examples/echo/main.go rename to internal/examples/echo/main.go diff --git a/examples/echo/server.go b/internal/examples/echo/server.go similarity index 95% rename from examples/echo/server.go rename to internal/examples/echo/server.go index e9f70f03..246ad582 100644 --- a/examples/echo/server.go +++ b/internal/examples/echo/server.go @@ -28,7 +28,7 @@ func (s echoServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { s.logf("%v", err) return } - defer c.Close(websocket.StatusInternalError, "the sky is falling") + defer c.CloseNow() if c.Subprotocol() != "echo" { c.Close(websocket.StatusPolicyViolation, "client must speak the echo subprotocol") diff --git a/examples/echo/server_test.go b/internal/examples/echo/server_test.go similarity index 100% rename from examples/echo/server_test.go rename to internal/examples/echo/server_test.go diff --git a/internal/examples/go.mod b/internal/examples/go.mod new file mode 100644 index 00000000..c98b81ce --- /dev/null +++ b/internal/examples/go.mod @@ -0,0 +1,10 @@ +module nhooyr.io/websocket/examples + +go 1.19 + +replace nhooyr.io/websocket => ../.. + +require ( + golang.org/x/time v0.3.0 + nhooyr.io/websocket v0.0.0-00010101000000-000000000000 +) diff --git a/internal/examples/go.sum b/internal/examples/go.sum new file mode 100644 index 00000000..f8a07e82 --- /dev/null +++ b/internal/examples/go.sum @@ -0,0 +1,2 @@ +golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4= +golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= diff --git a/internal/test/assert/assert.go b/internal/test/assert/assert.go index 6eaf7fc3..1b90cc9f 100644 --- a/internal/test/assert/assert.go +++ b/internal/test/assert/assert.go @@ -1,29 +1,19 @@ package assert import ( + "errors" "fmt" "reflect" "strings" "testing" - - "github.com/golang/protobuf/proto" - "github.com/google/go-cmp/cmp" - "github.com/google/go-cmp/cmp/cmpopts" ) -// Diff returns a human readable diff between v1 and v2 -func Diff(v1, v2 interface{}) string { - return cmp.Diff(v1, v2, cmpopts.EquateErrors(), cmp.Exporter(func(r reflect.Type) bool { - return true - }), cmp.Comparer(proto.Equal)) -} - // Equal asserts exp == act. -func Equal(t testing.TB, name string, exp, act interface{}) { +func Equal(t testing.TB, name string, exp, got interface{}) { t.Helper() - if diff := Diff(exp, act); diff != "" { - t.Fatalf("unexpected %v: %v", name, diff) + if !reflect.DeepEqual(exp, got) { + t.Fatalf("unexpected %v: expected %#v but got %#v", name, exp, got) } } @@ -54,3 +44,12 @@ func Contains(t testing.TB, v interface{}, sub string) { t.Fatalf("expected %q to contain %q", s, sub) } } + +// ErrorIs asserts errors.Is(got, exp) +func ErrorIs(t testing.TB, exp, got error) { + t.Helper() + + if !errors.Is(got, exp) { + t.Fatalf("expected %v but got %v", exp, got) + } +} diff --git a/internal/test/wstest/echo.go b/internal/test/wstest/echo.go index 8f4e47c8..dc21a8f0 100644 --- a/internal/test/wstest/echo.go +++ b/internal/test/wstest/echo.go @@ -8,7 +8,6 @@ import ( "time" "nhooyr.io/websocket" - "nhooyr.io/websocket/internal/test/assert" "nhooyr.io/websocket/internal/test/xrand" "nhooyr.io/websocket/internal/xsync" ) @@ -21,7 +20,7 @@ func EchoLoop(ctx context.Context, c *websocket.Conn) error { c.SetReadLimit(1 << 30) - ctx, cancel := context.WithTimeout(ctx, time.Minute) + ctx, cancel := context.WithTimeout(ctx, time.Minute*5) defer cancel() b := make([]byte, 32<<10) @@ -76,7 +75,7 @@ func Echo(ctx context.Context, c *websocket.Conn, max int) error { } if !bytes.Equal(msg, act) { - return fmt.Errorf("unexpected msg read: %v", assert.Diff(msg, act)) + return fmt.Errorf("unexpected msg read: %#v", act) } return nil diff --git a/internal/test/wstest/pipe.go b/internal/test/wstest/pipe.go index 1534f316..8e1deb47 100644 --- a/internal/test/wstest/pipe.go +++ b/internal/test/wstest/pipe.go @@ -1,3 +1,4 @@ +//go:build !js // +build !js package wstest @@ -24,7 +25,8 @@ func Pipe(dialOpts *websocket.DialOptions, acceptOpts *websocket.AcceptOptions) if dialOpts == nil { dialOpts = &websocket.DialOptions{} } - dialOpts = &*dialOpts + _dialOpts := *dialOpts + dialOpts = &_dialOpts dialOpts.HTTPClient = &http.Client{ Transport: tt, } diff --git a/internal/test/xrand/xrand.go b/internal/test/xrand/xrand.go index 8de1ede8..9bfb39ce 100644 --- a/internal/test/xrand/xrand.go +++ b/internal/test/xrand/xrand.go @@ -2,6 +2,7 @@ package xrand import ( "crypto/rand" + "encoding/base64" "fmt" "math/big" "strings" @@ -45,3 +46,8 @@ func Int(max int) int { } return int(x.Int64()) } + +// Base64 returns a randomly generated base64 string of length n. +func Base64(n int) string { + return base64.StdEncoding.EncodeToString(Bytes(n)) +} diff --git a/internal/thirdparty/doc.go b/internal/thirdparty/doc.go new file mode 100644 index 00000000..e756d09f --- /dev/null +++ b/internal/thirdparty/doc.go @@ -0,0 +1,2 @@ +// Package thirdparty contains third party benchmarks and tests. +package thirdparty diff --git a/internal/thirdparty/frame_test.go b/internal/thirdparty/frame_test.go new file mode 100644 index 00000000..1a0ed125 --- /dev/null +++ b/internal/thirdparty/frame_test.go @@ -0,0 +1,100 @@ +package thirdparty + +import ( + "encoding/binary" + "strconv" + "testing" + _ "unsafe" + + "github.com/gobwas/ws" + _ "github.com/gorilla/websocket" + + _ "nhooyr.io/websocket" +) + +func basicMask(maskKey [4]byte, pos int, b []byte) int { + for i := range b { + b[i] ^= maskKey[pos&3] + pos++ + } + return pos & 3 +} + +//go:linkname gorillaMaskBytes github.com/gorilla/websocket.maskBytes +func gorillaMaskBytes(key [4]byte, pos int, b []byte) int + +//go:linkname mask nhooyr.io/websocket.mask +func mask(key32 uint32, b []byte) int + +func Benchmark_mask(b *testing.B) { + sizes := []int{ + 2, + 3, + 4, + 8, + 16, + 32, + 128, + 512, + 4096, + 16384, + } + + fns := []struct { + name string + fn func(b *testing.B, key [4]byte, p []byte) + }{ + { + name: "basic", + fn: func(b *testing.B, key [4]byte, p []byte) { + for i := 0; i < b.N; i++ { + basicMask(key, 0, p) + } + }, + }, + + { + name: "nhooyr", + fn: func(b *testing.B, key [4]byte, p []byte) { + key32 := binary.LittleEndian.Uint32(key[:]) + b.ResetTimer() + + for i := 0; i < b.N; i++ { + mask(key32, p) + } + }, + }, + { + name: "gorilla", + fn: func(b *testing.B, key [4]byte, p []byte) { + for i := 0; i < b.N; i++ { + gorillaMaskBytes(key, 0, p) + } + }, + }, + { + name: "gobwas", + fn: func(b *testing.B, key [4]byte, p []byte) { + for i := 0; i < b.N; i++ { + ws.Cipher(p, key, 0) + } + }, + }, + } + + key := [4]byte{1, 2, 3, 4} + + for _, size := range sizes { + p := make([]byte, size) + + b.Run(strconv.Itoa(size), func(b *testing.B) { + for _, fn := range fns { + b.Run(fn.name, func(b *testing.B) { + b.SetBytes(int64(size)) + + fn.fn(b, key, p) + }) + } + }) + } +} diff --git a/internal/thirdparty/gin_test.go b/internal/thirdparty/gin_test.go new file mode 100644 index 00000000..6d59578d --- /dev/null +++ b/internal/thirdparty/gin_test.go @@ -0,0 +1,75 @@ +package thirdparty + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/gin-gonic/gin" + + "nhooyr.io/websocket" + "nhooyr.io/websocket/internal/errd" + "nhooyr.io/websocket/internal/test/assert" + "nhooyr.io/websocket/internal/test/wstest" + "nhooyr.io/websocket/wsjson" +) + +func TestGin(t *testing.T) { + t.Parallel() + + gin.SetMode(gin.ReleaseMode) + r := gin.New() + r.GET("/", func(ginCtx *gin.Context) { + err := echoServer(ginCtx.Writer, ginCtx.Request, nil) + if err != nil { + t.Error(err) + } + }) + + s := httptest.NewServer(r) + defer s.Close() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) + defer cancel() + + c, _, err := websocket.Dial(ctx, s.URL, nil) + assert.Success(t, err) + defer c.Close(websocket.StatusInternalError, "") + + err = wsjson.Write(ctx, c, "hello") + assert.Success(t, err) + + var v interface{} + err = wsjson.Read(ctx, c, &v) + assert.Success(t, err) + assert.Equal(t, "read msg", "hello", v) + + err = c.Close(websocket.StatusNormalClosure, "") + assert.Success(t, err) +} + +func echoServer(w http.ResponseWriter, r *http.Request, opts *websocket.AcceptOptions) (err error) { + defer errd.Wrap(&err, "echo server failed") + + c, err := websocket.Accept(w, r, opts) + if err != nil { + return err + } + defer c.Close(websocket.StatusInternalError, "") + + err = wstest.EchoLoop(r.Context(), c) + return assertCloseStatus(websocket.StatusNormalClosure, err) +} + +func assertCloseStatus(exp websocket.StatusCode, err error) error { + if websocket.CloseStatus(err) == -1 { + return fmt.Errorf("expected websocket.CloseError: %T %v", err, err) + } + if websocket.CloseStatus(err) != exp { + return fmt.Errorf("expected close status %v but got %v", exp, err) + } + return nil +} diff --git a/internal/thirdparty/go.mod b/internal/thirdparty/go.mod new file mode 100644 index 00000000..10eb45c1 --- /dev/null +++ b/internal/thirdparty/go.mod @@ -0,0 +1,41 @@ +module nhooyr.io/websocket/internal/thirdparty + +go 1.19 + +replace nhooyr.io/websocket => ../.. + +require ( + github.com/gin-gonic/gin v1.9.1 + github.com/gobwas/ws v1.3.0 + github.com/gorilla/websocket v1.5.0 + nhooyr.io/websocket v0.0.0-00010101000000-000000000000 +) + +require ( + github.com/bytedance/sonic v1.9.1 // indirect + github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect + github.com/gabriel-vasile/mimetype v1.4.2 // indirect + github.com/gin-contrib/sse v0.1.0 // indirect + github.com/go-playground/locales v0.14.1 // indirect + github.com/go-playground/universal-translator v0.18.1 // indirect + github.com/go-playground/validator/v10 v10.14.0 // indirect + github.com/gobwas/httphead v0.1.0 // indirect + github.com/gobwas/pool v0.2.1 // indirect + github.com/goccy/go-json v0.10.2 // indirect + github.com/json-iterator/go v1.1.12 // indirect + github.com/klauspost/cpuid/v2 v2.2.4 // indirect + github.com/leodido/go-urn v1.2.4 // indirect + github.com/mattn/go-isatty v0.0.19 // indirect + github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect + github.com/modern-go/reflect2 v1.0.2 // indirect + github.com/pelletier/go-toml/v2 v2.0.8 // indirect + github.com/twitchyliquid64/golang-asm v0.15.1 // indirect + github.com/ugorji/go/codec v1.2.11 // indirect + golang.org/x/arch v0.3.0 // indirect + golang.org/x/crypto v0.9.0 // indirect + golang.org/x/net v0.10.0 // indirect + golang.org/x/sys v0.8.0 // indirect + golang.org/x/text v0.9.0 // indirect + google.golang.org/protobuf v1.30.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/internal/thirdparty/go.sum b/internal/thirdparty/go.sum new file mode 100644 index 00000000..a9424b8d --- /dev/null +++ b/internal/thirdparty/go.sum @@ -0,0 +1,93 @@ +github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM= +github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s= +github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U= +github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY= +github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams= +github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU= +github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA= +github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= +github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= +github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg= +github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU= +github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= +github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= +github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= +github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= +github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= +github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg/+t63MyGU2n5js= +github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU= +github.com/gobwas/httphead v0.1.0 h1:exrUm0f4YX0L7EBwZHuCF4GDp8aJfVeBrlLQrs6NqWU= +github.com/gobwas/httphead v0.1.0/go.mod h1:O/RXo79gxV8G+RqlR/otEwx4Q36zl9rqC5u12GKvMCM= +github.com/gobwas/pool v0.2.1 h1:xfeeEhW7pwmX8nuLVlqbzVc7udMDrwetjEv+TZIz1og= +github.com/gobwas/pool v0.2.1/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw= +github.com/gobwas/ws v1.3.0 h1:sbeU3Y4Qzlb+MOzIe6mQGf7QR4Hkv6ZD0qhGkBFL2O0= +github.com/gobwas/ws v1.3.0/go.mod h1:hRKAFb8wOxFROYNsT1bqfWnhX+b5MFeJM9r2ZSwg/KY= +github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= +github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= +github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= +github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= +github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= +github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= +github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= +github.com/klauspost/cpuid/v2 v2.2.4 h1:acbojRNwl3o09bUq+yDCtZFc1aiwaAAxtcn8YkZXnvk= +github.com/klauspost/cpuid/v2 v2.2.4/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY= +github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q= +github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4= +github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= +github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= +github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= +github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ= +github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.8.3 h1:RP3t2pwF7cMEbC1dqtB6poj3niw/9gnV4Cjg5oW5gtY= +github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= +github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= +github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU= +github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= +golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= +golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k= +golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= +golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g= +golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0= +golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M= +golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= +golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU= +golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE= +golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= +google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng= +google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= diff --git a/internal/util/util.go b/internal/util/util.go new file mode 100644 index 00000000..aa210703 --- /dev/null +++ b/internal/util/util.go @@ -0,0 +1,15 @@ +package util + +// WriterFunc is used to implement one off io.Writers. +type WriterFunc func(p []byte) (int, error) + +func (f WriterFunc) Write(p []byte) (int, error) { + return f(p) +} + +// ReaderFunc is used to implement one off io.Readers. +type ReaderFunc func(p []byte) (int, error) + +func (f ReaderFunc) Read(p []byte) (int, error) { + return f(p) +} diff --git a/internal/wsjs/wsjs_js.go b/internal/wsjs/wsjs_js.go index 26ffb456..11eb59cb 100644 --- a/internal/wsjs/wsjs_js.go +++ b/internal/wsjs/wsjs_js.go @@ -1,3 +1,4 @@ +//go:build js // +build js // Package wsjs implements typed access to the browser javascript WebSocket API. @@ -118,8 +119,6 @@ func (c WebSocket) OnMessage(fn func(m MessageEvent)) (remove func()) { Data: data, } fn(me) - - return }) } diff --git a/internal/xsync/go.go b/internal/xsync/go.go index 7a61f27f..5229b12a 100644 --- a/internal/xsync/go.go +++ b/internal/xsync/go.go @@ -2,6 +2,7 @@ package xsync import ( "fmt" + "runtime/debug" ) // Go allows running a function in another goroutine @@ -13,7 +14,7 @@ func Go(fn func() error) <-chan error { r := recover() if r != nil { select { - case errs <- fmt.Errorf("panic in go fn: %v", r): + case errs <- fmt.Errorf("panic in go fn: %v, %s", r, debug.Stack()): default: } } diff --git a/main_test.go b/main_test.go new file mode 100644 index 00000000..2b93bb18 --- /dev/null +++ b/main_test.go @@ -0,0 +1,30 @@ +package websocket_test + +import ( + "fmt" + "os" + "runtime" + "testing" +) + +func goroutineStacks() []byte { + buf := make([]byte, 512) + for { + m := runtime.Stack(buf, true) + if m < len(buf) { + return buf[:m] + } + buf = make([]byte, len(buf)*2) + } +} + +func TestMain(m *testing.M) { + code := m.Run() + if runtime.GOOS != "js" && runtime.NumGoroutine() != 1 || + runtime.GOOS == "js" && runtime.NumGoroutine() != 2 { + fmt.Fprintf(os.Stderr, "goroutine leak detected, expected 1 but got %d goroutines\n", runtime.NumGoroutine()) + fmt.Fprintf(os.Stderr, "%s\n", goroutineStacks()) + os.Exit(1) + } + os.Exit(code) +} diff --git a/make.sh b/make.sh new file mode 100755 index 00000000..170d00a8 --- /dev/null +++ b/make.sh @@ -0,0 +1,12 @@ +#!/bin/sh +set -eu +cd -- "$(dirname "$0")" + +echo "=== fmt.sh" +./ci/fmt.sh +echo "=== lint.sh" +./ci/lint.sh +echo "=== test.sh" +./ci/test.sh "$@" +echo "=== bench.sh" +./ci/bench.sh diff --git a/netconn.go b/netconn.go index 64aadf0b..1667f45c 100644 --- a/netconn.go +++ b/netconn.go @@ -6,7 +6,7 @@ import ( "io" "math" "net" - "sync" + "sync/atomic" "time" ) @@ -28,30 +28,64 @@ import ( // // Close will close the *websocket.Conn with StatusNormalClosure. // -// When a deadline is hit, the connection will be closed. This is -// different from most net.Conn implementations where only the -// reading/writing goroutines are interrupted but the connection is kept alive. +// When a deadline is hit and there is an active read or write goroutine, the +// connection will be closed. This is different from most net.Conn implementations +// where only the reading/writing goroutines are interrupted but the connection +// is kept alive. // -// The Addr methods will return a mock net.Addr that returns "websocket" for Network -// and "websocket/unknown-addr" for String. +// The Addr methods will return the real addresses for connections obtained +// from websocket.Accept. But for connections obtained from websocket.Dial, a mock net.Addr +// will be returned that gives "websocket" for Network() and "websocket/unknown-addr" for +// String(). This is because websocket.Dial only exposes a io.ReadWriteCloser instead of the +// full net.Conn to us. +// +// When running as WASM, the Addr methods will always return the mock address described above. // // A received StatusNormalClosure or StatusGoingAway close frame will be translated to // io.EOF when reading. +// +// Furthermore, the ReadLimit is set to -1 to disable it. func NetConn(ctx context.Context, c *Conn, msgType MessageType) net.Conn { + c.SetReadLimit(-1) + nc := &netConn{ c: c, msgType: msgType, + readMu: newMu(c), + writeMu: newMu(c), } - var cancel context.CancelFunc - nc.writeContext, cancel = context.WithCancel(ctx) - nc.writeTimer = time.AfterFunc(math.MaxInt64, cancel) + nc.writeCtx, nc.writeCancel = context.WithCancel(ctx) + nc.readCtx, nc.readCancel = context.WithCancel(ctx) + + nc.writeTimer = time.AfterFunc(math.MaxInt64, func() { + if !nc.writeMu.tryLock() { + // If the lock cannot be acquired, then there is an + // active write goroutine and so we should cancel the context. + nc.writeCancel() + return + } + defer nc.writeMu.unlock() + + // Prevents future writes from writing until the deadline is reset. + atomic.StoreInt64(&nc.writeExpired, 1) + }) if !nc.writeTimer.Stop() { <-nc.writeTimer.C } - nc.readContext, cancel = context.WithCancel(ctx) - nc.readTimer = time.AfterFunc(math.MaxInt64, cancel) + nc.readTimer = time.AfterFunc(math.MaxInt64, func() { + if !nc.readMu.tryLock() { + // If the lock cannot be acquired, then there is an + // active read goroutine and so we should cancel the context. + nc.readCancel() + return + } + defer nc.readMu.unlock() + + // Prevents future reads from reading until the deadline is reset. + atomic.StoreInt64(&nc.readExpired, 1) + }) if !nc.readTimer.Stop() { <-nc.readTimer.C } @@ -64,59 +98,91 @@ type netConn struct { msgType MessageType writeTimer *time.Timer - writeContext context.Context + writeMu *mu + writeExpired int64 + writeCtx context.Context + writeCancel context.CancelFunc readTimer *time.Timer - readContext context.Context - - readMu sync.Mutex - eofed bool - reader io.Reader + readMu *mu + readExpired int64 + readCtx context.Context + readCancel context.CancelFunc + readEOFed bool + reader io.Reader } var _ net.Conn = &netConn{} -func (c *netConn) Close() error { - return c.c.Close(StatusNormalClosure, "") +func (nc *netConn) Close() error { + nc.writeTimer.Stop() + nc.writeCancel() + nc.readTimer.Stop() + nc.readCancel() + return nc.c.Close(StatusNormalClosure, "") } -func (c *netConn) Write(p []byte) (int, error) { - err := c.c.Write(c.writeContext, c.msgType, p) +func (nc *netConn) Write(p []byte) (int, error) { + nc.writeMu.forceLock() + defer nc.writeMu.unlock() + + if atomic.LoadInt64(&nc.writeExpired) == 1 { + return 0, fmt.Errorf("failed to write: %w", context.DeadlineExceeded) + } + + err := nc.c.Write(nc.writeCtx, nc.msgType, p) if err != nil { return 0, err } return len(p), nil } -func (c *netConn) Read(p []byte) (int, error) { - c.readMu.Lock() - defer c.readMu.Unlock() +func (nc *netConn) Read(p []byte) (int, error) { + nc.readMu.forceLock() + defer nc.readMu.unlock() + + for { + n, err := nc.read(p) + if err != nil { + return n, err + } + if n == 0 { + continue + } + return n, nil + } +} + +func (nc *netConn) read(p []byte) (int, error) { + if atomic.LoadInt64(&nc.readExpired) == 1 { + return 0, fmt.Errorf("failed to read: %w", context.DeadlineExceeded) + } - if c.eofed { + if nc.readEOFed { return 0, io.EOF } - if c.reader == nil { - typ, r, err := c.c.Reader(c.readContext) + if nc.reader == nil { + typ, r, err := nc.c.Reader(nc.readCtx) if err != nil { switch CloseStatus(err) { case StatusNormalClosure, StatusGoingAway: - c.eofed = true + nc.readEOFed = true return 0, io.EOF } return 0, err } - if typ != c.msgType { - err := fmt.Errorf("unexpected frame type read (expected %v): %v", c.msgType, typ) - c.c.Close(StatusUnsupportedData, err.Error()) + if typ != nc.msgType { + err := fmt.Errorf("unexpected frame type read (expected %v): %v", nc.msgType, typ) + nc.c.Close(StatusUnsupportedData, err.Error()) return 0, err } - c.reader = r + nc.reader = r } - n, err := c.reader.Read(p) + n, err := nc.reader.Read(p) if err == io.EOF { - c.reader = nil + nc.reader = nil err = nil } return n, err @@ -133,34 +199,36 @@ func (a websocketAddr) String() string { return "websocket/unknown-addr" } -func (c *netConn) RemoteAddr() net.Addr { - return websocketAddr{} -} - -func (c *netConn) LocalAddr() net.Addr { - return websocketAddr{} -} - -func (c *netConn) SetDeadline(t time.Time) error { - c.SetWriteDeadline(t) - c.SetReadDeadline(t) +func (nc *netConn) SetDeadline(t time.Time) error { + nc.SetWriteDeadline(t) + nc.SetReadDeadline(t) return nil } -func (c *netConn) SetWriteDeadline(t time.Time) error { +func (nc *netConn) SetWriteDeadline(t time.Time) error { + atomic.StoreInt64(&nc.writeExpired, 0) if t.IsZero() { - c.writeTimer.Stop() + nc.writeTimer.Stop() } else { - c.writeTimer.Reset(t.Sub(time.Now())) + dur := time.Until(t) + if dur <= 0 { + dur = 1 + } + nc.writeTimer.Reset(dur) } return nil } -func (c *netConn) SetReadDeadline(t time.Time) error { +func (nc *netConn) SetReadDeadline(t time.Time) error { + atomic.StoreInt64(&nc.readExpired, 0) if t.IsZero() { - c.readTimer.Stop() + nc.readTimer.Stop() } else { - c.readTimer.Reset(t.Sub(time.Now())) + dur := time.Until(t) + if dur <= 0 { + dur = 1 + } + nc.readTimer.Reset(dur) } return nil } diff --git a/netconn_js.go b/netconn_js.go new file mode 100644 index 00000000..ccc8c89f --- /dev/null +++ b/netconn_js.go @@ -0,0 +1,11 @@ +package websocket + +import "net" + +func (nc *netConn) RemoteAddr() net.Addr { + return websocketAddr{} +} + +func (nc *netConn) LocalAddr() net.Addr { + return websocketAddr{} +} diff --git a/netconn_notjs.go b/netconn_notjs.go new file mode 100644 index 00000000..f3eb0d66 --- /dev/null +++ b/netconn_notjs.go @@ -0,0 +1,20 @@ +//go:build !js +// +build !js + +package websocket + +import "net" + +func (nc *netConn) RemoteAddr() net.Addr { + if unc, ok := nc.c.rwc.(net.Conn); ok { + return unc.RemoteAddr() + } + return websocketAddr{} +} + +func (nc *netConn) LocalAddr() net.Addr { + if unc, ok := nc.c.rwc.(net.Conn); ok { + return unc.LocalAddr() + } + return websocketAddr{} +} diff --git a/read.go b/read.go index 89a00988..8742842e 100644 --- a/read.go +++ b/read.go @@ -1,3 +1,4 @@ +//go:build !js // +build !js package websocket @@ -8,11 +9,12 @@ import ( "errors" "fmt" "io" - "io/ioutil" + "net" "strings" "time" "nhooyr.io/websocket/internal/errd" + "nhooyr.io/websocket/internal/util" "nhooyr.io/websocket/internal/xsync" ) @@ -26,6 +28,11 @@ import ( // Call CloseRead if you do not expect any data messages from the peer. // // Only one Reader may be open at a time. +// +// If you need a separate timeout on the Reader call and the Read itself, +// use time.AfterFunc to cancel the context passed in. +// See https://github.com/nhooyr/websocket/issues/87#issue-451703332 +// Most users should not need this. func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) { return c.reader(ctx) } @@ -38,7 +45,7 @@ func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) { return 0, nil, err } - b, err := ioutil.ReadAll(r) + b, err := io.ReadAll(r) return typ, b, err } @@ -55,10 +62,16 @@ func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) { // frames are responded to. This means c.Ping and c.Close will still work as expected. func (c *Conn) CloseRead(ctx context.Context) context.Context { ctx, cancel := context.WithCancel(ctx) + + c.wg.Add(1) go func() { + defer c.CloseNow() + defer c.wg.Done() defer cancel() - c.Reader(ctx) - c.Close(StatusPolicyViolation, "unexpected data message") + _, _, err := c.Reader(ctx) + if err == nil { + c.Close(StatusPolicyViolation, "unexpected data message") + } }() return ctx } @@ -69,10 +82,16 @@ func (c *Conn) CloseRead(ctx context.Context) context.Context { // By default, the connection has a message read limit of 32768 bytes. // // When the limit is hit, the connection will be closed with StatusMessageTooBig. +// +// Set to -1 to disable. func (c *Conn) SetReadLimit(n int64) { - // We add read one more byte than the limit in case - // there is a fin frame that needs to be read. - c.msgReader.limitReader.limit.Store(n + 1) + if n >= 0 { + // We read one more byte than the limit in case + // there is a fin frame that needs to be read. + n++ + } + + c.msgReader.limitReader.limit.Store(n) } const defaultReadLimit = 32768 @@ -90,13 +109,20 @@ func newMsgReader(c *Conn) *msgReader { func (mr *msgReader) resetFlate() { if mr.flateContextTakeover() { + if mr.dict == nil { + mr.dict = &slidingWindow{} + } mr.dict.init(32768) } if mr.flateBufio == nil { mr.flateBufio = getBufioReader(mr.readFunc) } - mr.flateReader = getFlateReader(mr.flateBufio, mr.dict.buf) + if mr.flateContextTakeover() { + mr.flateReader = getFlateReader(mr.flateBufio, mr.dict.buf) + } else { + mr.flateReader = getFlateReader(mr.flateBufio, nil) + } mr.limitReader.r = mr.flateReader mr.flateTail.Reset(deflateMessageTail) } @@ -111,7 +137,10 @@ func (mr *msgReader) putFlateReader() { func (mr *msgReader) close() { mr.c.readMu.forceLock() mr.putFlateReader() - mr.dict.close() + if mr.dict != nil { + mr.dict.close() + mr.dict = nil + } if mr.flateBufio != nil { putBufioReader(mr.flateBufio) } @@ -181,7 +210,7 @@ func (c *Conn) readLoop(ctx context.Context) (header, error) { func (c *Conn) readFrameHeader(ctx context.Context) (header, error) { select { case <-c.closed: - return header{}, c.closeErr + return header{}, net.ErrClosed case c.readTimeout <- ctx: } @@ -189,7 +218,7 @@ func (c *Conn) readFrameHeader(ctx context.Context) (header, error) { if err != nil { select { case <-c.closed: - return header{}, c.closeErr + return header{}, net.ErrClosed case <-ctx.Done(): return header{}, ctx.Err() default: @@ -200,7 +229,7 @@ func (c *Conn) readFrameHeader(ctx context.Context) (header, error) { select { case <-c.closed: - return header{}, c.closeErr + return header{}, net.ErrClosed case c.readTimeout <- context.Background(): } @@ -210,7 +239,7 @@ func (c *Conn) readFrameHeader(ctx context.Context) (header, error) { func (c *Conn) readFramePayload(ctx context.Context, p []byte) (int, error) { select { case <-c.closed: - return 0, c.closeErr + return 0, net.ErrClosed case c.readTimeout <- ctx: } @@ -218,7 +247,7 @@ func (c *Conn) readFramePayload(ctx context.Context, p []byte) (int, error) { if err != nil { select { case <-c.closed: - return n, c.closeErr + return n, net.ErrClosed case <-ctx.Done(): return n, ctx.Err() default: @@ -230,7 +259,7 @@ func (c *Conn) readFramePayload(ctx context.Context, p []byte) (int, error) { select { case <-c.closed: - return n, c.closeErr + return n, net.ErrClosed case c.readTimeout <- context.Background(): } @@ -337,14 +366,14 @@ type msgReader struct { flateBufio *bufio.Reader flateTail strings.Reader limitReader *limitReader - dict slidingWindow + dict *slidingWindow fin bool payloadLength int64 maskKey uint32 - // readerFunc(mr.Read) to avoid continuous allocations. - readFunc readerFunc + // util.ReaderFunc(mr.Read) to avoid continuous allocations. + readFunc util.ReaderFunc } func (mr *msgReader) reset(ctx context.Context, h header) { @@ -453,7 +482,11 @@ func (lr *limitReader) reset(r io.Reader) { } func (lr *limitReader) Read(p []byte) (int, error) { - if lr.n <= 0 { + if lr.n < 0 { + return lr.r.Read(p) + } + + if lr.n == 0 { err := fmt.Errorf("read limited at %v bytes", lr.limit.Load()) lr.c.writeError(StatusMessageTooBig, err) return 0, err @@ -464,11 +497,8 @@ func (lr *limitReader) Read(p []byte) (int, error) { } n, err := lr.r.Read(p) lr.n -= int64(n) + if lr.n < 0 { + lr.n = 0 + } return n, err } - -type readerFunc func(p []byte) (int, error) - -func (f readerFunc) Read(p []byte) (int, error) { - return f(p) -} diff --git a/write.go b/write.go index 2210cf81..7b1152ce 100644 --- a/write.go +++ b/write.go @@ -1,3 +1,4 @@ +//go:build !js // +build !js package websocket @@ -10,11 +11,13 @@ import ( "errors" "fmt" "io" + "net" "time" - "github.com/klauspost/compress/flate" + "compress/flate" "nhooyr.io/websocket/internal/errd" + "nhooyr.io/websocket/internal/util" ) // Writer returns a writer bounded by the context that will write @@ -36,7 +39,7 @@ func (c *Conn) Writer(ctx context.Context, typ MessageType) (io.WriteCloser, err // // See the Writer method if you want to stream a message. // -// If compression is disabled or the threshold is not met, then it +// If compression is disabled or the compression threshold is not met, then it // will write the message in a single frame. func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error { _, err := c.write(ctx, typ, p) @@ -47,41 +50,22 @@ func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error { } type msgWriter struct { - mw *msgWriterState - closed bool -} - -func (mw *msgWriter) Write(p []byte) (int, error) { - if mw.closed { - return 0, errors.New("cannot use closed writer") - } - return mw.mw.Write(p) -} - -func (mw *msgWriter) Close() error { - if mw.closed { - return errors.New("cannot use closed writer") - } - mw.closed = true - return mw.mw.Close() -} - -type msgWriterState struct { c *Conn mu *mu writeMu *mu + closed bool ctx context.Context opcode opcode flate bool - trimWriter *trimLastFourBytesWriter - dict slidingWindow + trimWriter *trimLastFourBytesWriter + flateWriter *flate.Writer } -func newMsgWriterState(c *Conn) *msgWriterState { - mw := &msgWriterState{ +func newMsgWriter(c *Conn) *msgWriter { + mw := &msgWriter{ c: c, mu: newMu(c), writeMu: newMu(c), @@ -89,18 +73,20 @@ func newMsgWriterState(c *Conn) *msgWriterState { return mw } -func (mw *msgWriterState) ensureFlate() { +func (mw *msgWriter) ensureFlate() { if mw.trimWriter == nil { mw.trimWriter = &trimLastFourBytesWriter{ - w: writerFunc(mw.write), + w: util.WriterFunc(mw.write), } } - mw.dict.init(8192) + if mw.flateWriter == nil { + mw.flateWriter = getFlateWriter(mw.trimWriter) + } mw.flate = true } -func (mw *msgWriterState) flateContextTakeover() bool { +func (mw *msgWriter) flateContextTakeover() bool { if mw.c.client { return !mw.c.copts.clientNoContextTakeover } @@ -108,14 +94,11 @@ func (mw *msgWriterState) flateContextTakeover() bool { } func (c *Conn) writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) { - err := c.msgWriterState.reset(ctx, typ) + err := c.msgWriter.reset(ctx, typ) if err != nil { return nil, err } - return &msgWriter{ - mw: c.msgWriterState, - closed: false, - }, nil + return c.msgWriter, nil } func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error) { @@ -125,8 +108,8 @@ func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error } if !c.flate() { - defer c.msgWriterState.mu.unlock() - return c.writeFrame(ctx, true, false, c.msgWriterState.opcode, p) + defer c.msgWriter.mu.unlock() + return c.writeFrame(ctx, true, false, c.msgWriter.opcode, p) } n, err := mw.Write(p) @@ -138,7 +121,7 @@ func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error return n, err } -func (mw *msgWriterState) reset(ctx context.Context, typ MessageType) error { +func (mw *msgWriter) reset(ctx context.Context, typ MessageType) error { err := mw.mu.lock(ctx) if err != nil { return err @@ -147,20 +130,32 @@ func (mw *msgWriterState) reset(ctx context.Context, typ MessageType) error { mw.ctx = ctx mw.opcode = opcode(typ) mw.flate = false + mw.closed = false mw.trimWriter.reset() return nil } +func (mw *msgWriter) putFlateWriter() { + if mw.flateWriter != nil { + putFlateWriter(mw.flateWriter) + mw.flateWriter = nil + } +} + // Write writes the given bytes to the WebSocket connection. -func (mw *msgWriterState) Write(p []byte) (_ int, err error) { +func (mw *msgWriter) Write(p []byte) (_ int, err error) { err = mw.writeMu.lock(mw.ctx) if err != nil { return 0, fmt.Errorf("failed to write: %w", err) } defer mw.writeMu.unlock() + if mw.closed { + return 0, errors.New("cannot use closed writer") + } + defer func() { if err != nil { err = fmt.Errorf("failed to write: %w", err) @@ -177,18 +172,13 @@ func (mw *msgWriterState) Write(p []byte) (_ int, err error) { } if mw.flate { - err = flate.StatelessDeflate(mw.trimWriter, p, false, mw.dict.buf) - if err != nil { - return 0, err - } - mw.dict.write(p) - return len(p), nil + return mw.flateWriter.Write(p) } return mw.write(p) } -func (mw *msgWriterState) write(p []byte) (int, error) { +func (mw *msgWriter) write(p []byte) (int, error) { n, err := mw.c.writeFrame(mw.ctx, false, mw.flate, mw.opcode, p) if err != nil { return n, fmt.Errorf("failed to write data frame: %w", err) @@ -198,7 +188,7 @@ func (mw *msgWriterState) write(p []byte) (int, error) { } // Close flushes the frame to the connection. -func (mw *msgWriterState) Close() (err error) { +func (mw *msgWriter) Close() (err error) { defer errd.Wrap(&err, "failed to close writer") err = mw.writeMu.lock(mw.ctx) @@ -207,26 +197,38 @@ func (mw *msgWriterState) Close() (err error) { } defer mw.writeMu.unlock() + if mw.closed { + return errors.New("writer already closed") + } + mw.closed = true + + if mw.flate { + err = mw.flateWriter.Flush() + if err != nil { + return fmt.Errorf("failed to flush flate: %w", err) + } + } + _, err = mw.c.writeFrame(mw.ctx, true, mw.flate, mw.opcode, nil) if err != nil { return fmt.Errorf("failed to write fin frame: %w", err) } if mw.flate && !mw.flateContextTakeover() { - mw.dict.close() + mw.putFlateWriter() } mw.mu.unlock() return nil } -func (mw *msgWriterState) close() { +func (mw *msgWriter) close() { if mw.c.client { mw.c.writeFrameMu.forceLock() putBufioWriter(mw.c.bw) } mw.writeMu.forceLock() - mw.dict.close() + mw.putFlateWriter() } func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error { @@ -246,7 +248,6 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opco if err != nil { return 0, err } - defer c.writeFrameMu.unlock() // If the state says a close has already been written, we wait until // the connection is closed and return that error. @@ -257,17 +258,19 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opco wroteClose := c.wroteClose c.closeMu.Unlock() if wroteClose && opcode != opClose { + c.writeFrameMu.unlock() select { case <-ctx.Done(): return 0, ctx.Err() case <-c.closed: - return 0, c.closeErr + return 0, net.ErrClosed } } + defer c.writeFrameMu.unlock() select { case <-c.closed: - return 0, c.closeErr + return 0, net.ErrClosed case c.writeTimeout <- ctx: } @@ -275,9 +278,10 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opco if err != nil { select { case <-c.closed: - err = c.closeErr + err = net.ErrClosed case <-ctx.Done(): err = ctx.Err() + default: } c.close(err) err = fmt.Errorf("failed to write frame: %w", err) @@ -321,7 +325,10 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opco select { case <-c.closed: - return n, c.closeErr + if opcode == opClose { + return n, nil + } + return n, net.ErrClosed case c.writeTimeout <- context.Background(): } @@ -367,17 +374,11 @@ func (c *Conn) writeFramePayload(p []byte) (n int, err error) { return n, nil } -type writerFunc func(p []byte) (int, error) - -func (f writerFunc) Write(p []byte) (int, error) { - return f(p) -} - // extractBufioWriterBuf grabs the []byte backing a *bufio.Writer // and returns it. func extractBufioWriterBuf(bw *bufio.Writer, w io.Writer) []byte { var writeBuf []byte - bw.Reset(writerFunc(func(p2 []byte) (int, error) { + bw.Reset(util.WriterFunc(func(p2 []byte) (int, error) { writeBuf = p2[:cap(p2)] return len(p2), nil })) diff --git a/ws_js.go b/ws_js.go index b87e32cd..b4011b5c 100644 --- a/ws_js.go +++ b/ws_js.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "io" + "net" "net/http" "reflect" "runtime" @@ -18,13 +19,35 @@ import ( "nhooyr.io/websocket/internal/xsync" ) +// opcode represents a WebSocket opcode. +type opcode int + +// https://tools.ietf.org/html/rfc6455#section-11.8. +const ( + opContinuation opcode = iota + opText + opBinary + // 3 - 7 are reserved for further non-control frames. + _ + _ + _ + _ + _ + opClose + opPing + opPong + // 11-16 are reserved for further control frames. +) + // Conn provides a wrapper around the browser WebSocket API. type Conn struct { + noCopy ws wsjs.WebSocket // read limit for a message in bytes. msgReadLimit xsync.Int64 + wg sync.WaitGroup closingMu sync.Mutex isReadClosed xsync.Int64 closeOnce sync.Once @@ -34,6 +57,7 @@ type Conn struct { closeWasClean bool releaseOnClose func() + releaseOnError func() releaseOnMessage func() readSignal chan struct{} @@ -71,9 +95,15 @@ func (c *Conn) init() { c.close(err, e.WasClean) c.releaseOnClose() + c.releaseOnError() c.releaseOnMessage() }) + c.releaseOnError = c.ws.OnError(func(v js.Value) { + c.setCloseErr(errors.New(v.Get("message").String())) + c.closeWithInternal() + }) + c.releaseOnMessage = c.ws.OnMessage(func(e wsjs.MessageEvent) { c.readBufMu.Lock() defer c.readBufMu.Unlock() @@ -123,7 +153,7 @@ func (c *Conn) read(ctx context.Context) (MessageType, []byte, error) { return 0, nil, ctx.Err() case <-c.readSignal: case <-c.closed: - return 0, nil, c.closeErr + return 0, nil, net.ErrClosed } c.readBufMu.Lock() @@ -177,7 +207,7 @@ func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error { func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) error { if c.isClosed() { - return c.closeErr + return net.ErrClosed } switch typ { case MessageBinary: @@ -194,6 +224,7 @@ func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) error { // or the connection is closed. // It thus performs the full WebSocket close handshake. func (c *Conn) Close(code StatusCode, reason string) error { + defer c.wg.Wait() err := c.exportedClose(code, reason) if err != nil { return fmt.Errorf("failed to close WebSocket: %w", err) @@ -201,19 +232,29 @@ func (c *Conn) Close(code StatusCode, reason string) error { return nil } +// CloseNow closes the WebSocket connection without attempting a close handshake. +// Use when you do not want the overhead of the close handshake. +// +// note: No different from Close(StatusGoingAway, "") in WASM as there is no way to close +// a WebSocket without the close handshake. +func (c *Conn) CloseNow() error { + defer c.wg.Wait() + return c.Close(StatusGoingAway, "") +} + func (c *Conn) exportedClose(code StatusCode, reason string) error { c.closingMu.Lock() defer c.closingMu.Unlock() + if c.isClosed() { + return net.ErrClosed + } + ce := fmt.Errorf("sent close: %w", CloseError{ Code: code, Reason: reason, }) - if c.isClosed() { - return fmt.Errorf("tried to close with %q but connection already closed: %w", ce, c.closeErr) - } - c.setCloseErr(ce) err := c.ws.Close(int(code), reason) if err != nil { @@ -284,7 +325,7 @@ func dial(ctx context.Context, url string, opts *DialOptions) (*Conn, *http.Resp StatusCode: http.StatusSwitchingProtocols, }, nil case <-c.closed: - return nil, nil, c.closeErr + return nil, nil, net.ErrClosed } } @@ -302,7 +343,7 @@ func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) { // It buffers the entire message in memory and then sends it when the writer // is closed. func (c *Conn) Writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) { - return writer{ + return &writer{ c: c, ctx: ctx, typ: typ, @@ -320,7 +361,7 @@ type writer struct { b *bytes.Buffer } -func (w writer) Write(p []byte) (int, error) { +func (w *writer) Write(p []byte) (int, error) { if w.closed { return 0, errors.New("cannot write to closed writer") } @@ -331,7 +372,7 @@ func (w writer) Write(p []byte) (int, error) { return n, nil } -func (w writer) Close() error { +func (w *writer) Close() error { if w.closed { return errors.New("cannot close closed writer") } @@ -350,10 +391,15 @@ func (c *Conn) CloseRead(ctx context.Context) context.Context { c.isReadClosed.Store(1) ctx, cancel := context.WithCancel(ctx) + c.wg.Add(1) go func() { + defer c.CloseNow() + defer c.wg.Done() defer cancel() - c.read(ctx) - c.Close(StatusPolicyViolation, "unexpected data message") + _, _, err := c.read(ctx) + if err != nil { + c.Close(StatusPolicyViolation, "unexpected data message") + } }() return ctx } @@ -377,3 +423,168 @@ func (c *Conn) isClosed() bool { return false } } + +// AcceptOptions represents Accept's options. +type AcceptOptions struct { + Subprotocols []string + InsecureSkipVerify bool + OriginPatterns []string + CompressionMode CompressionMode + CompressionThreshold int +} + +// Accept is stubbed out for Wasm. +func Accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, error) { + return nil, errors.New("unimplemented") +} + +// StatusCode represents a WebSocket status code. +// https://tools.ietf.org/html/rfc6455#section-7.4 +type StatusCode int + +// https://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number +// +// These are only the status codes defined by the protocol. +// +// You can define custom codes in the 3000-4999 range. +// The 3000-3999 range is reserved for use by libraries, frameworks and applications. +// The 4000-4999 range is reserved for private use. +const ( + StatusNormalClosure StatusCode = 1000 + StatusGoingAway StatusCode = 1001 + StatusProtocolError StatusCode = 1002 + StatusUnsupportedData StatusCode = 1003 + + // 1004 is reserved and so unexported. + statusReserved StatusCode = 1004 + + // StatusNoStatusRcvd cannot be sent in a close message. + // It is reserved for when a close message is received without + // a status code. + StatusNoStatusRcvd StatusCode = 1005 + + // StatusAbnormalClosure is exported for use only with Wasm. + // In non Wasm Go, the returned error will indicate whether the + // connection was closed abnormally. + StatusAbnormalClosure StatusCode = 1006 + + StatusInvalidFramePayloadData StatusCode = 1007 + StatusPolicyViolation StatusCode = 1008 + StatusMessageTooBig StatusCode = 1009 + StatusMandatoryExtension StatusCode = 1010 + StatusInternalError StatusCode = 1011 + StatusServiceRestart StatusCode = 1012 + StatusTryAgainLater StatusCode = 1013 + StatusBadGateway StatusCode = 1014 + + // StatusTLSHandshake is only exported for use with Wasm. + // In non Wasm Go, the returned error will indicate whether there was + // a TLS handshake failure. + StatusTLSHandshake StatusCode = 1015 +) + +// CloseError is returned when the connection is closed with a status and reason. +// +// Use Go 1.13's errors.As to check for this error. +// Also see the CloseStatus helper. +type CloseError struct { + Code StatusCode + Reason string +} + +func (ce CloseError) Error() string { + return fmt.Sprintf("status = %v and reason = %q", ce.Code, ce.Reason) +} + +// CloseStatus is a convenience wrapper around Go 1.13's errors.As to grab +// the status code from a CloseError. +// +// -1 will be returned if the passed error is nil or not a CloseError. +func CloseStatus(err error) StatusCode { + var ce CloseError + if errors.As(err, &ce) { + return ce.Code + } + return -1 +} + +// CompressionMode represents the modes available to the deflate extension. +// See https://tools.ietf.org/html/rfc7692 +// Works in all browsers except Safari which does not implement the deflate extension. +type CompressionMode int + +const ( + // CompressionNoContextTakeover grabs a new flate.Reader and flate.Writer as needed + // for every message. This applies to both server and client side. + // + // This means less efficient compression as the sliding window from previous messages + // will not be used but the memory overhead will be lower if the connections + // are long lived and seldom used. + // + // The message will only be compressed if greater than 512 bytes. + CompressionNoContextTakeover CompressionMode = iota + + // CompressionContextTakeover uses a flate.Reader and flate.Writer per connection. + // This enables reusing the sliding window from previous messages. + // As most WebSocket protocols are repetitive, this can be very efficient. + // It carries an overhead of 8 kB for every connection compared to CompressionNoContextTakeover. + // + // If the peer negotiates NoContextTakeover on the client or server side, it will be + // used instead as this is required by the RFC. + CompressionContextTakeover + + // CompressionDisabled disables the deflate extension. + // + // Use this if you are using a predominantly binary protocol with very + // little duplication in between messages or CPU and memory are more + // important than bandwidth. + CompressionDisabled +) + +// MessageType represents the type of a WebSocket message. +// See https://tools.ietf.org/html/rfc6455#section-5.6 +type MessageType int + +// MessageType constants. +const ( + // MessageText is for UTF-8 encoded text messages like JSON. + MessageText MessageType = iota + 1 + // MessageBinary is for binary messages like protobufs. + MessageBinary +) + +type mu struct { + c *Conn + ch chan struct{} +} + +func newMu(c *Conn) *mu { + return &mu{ + c: c, + ch: make(chan struct{}, 1), + } +} + +func (m *mu) forceLock() { + m.ch <- struct{}{} +} + +func (m *mu) tryLock() bool { + select { + case m.ch <- struct{}{}: + return true + default: + return false + } +} + +func (m *mu) unlock() { + select { + case <-m.ch: + default: + } +} + +type noCopy struct{} + +func (*noCopy) Lock() {} diff --git a/ws_js_test.go b/ws_js_test.go index e6be6181..ba98b9a0 100644 --- a/ws_js_test.go +++ b/ws_js_test.go @@ -36,3 +36,19 @@ func TestWasm(t *testing.T) { err = c.Close(websocket.StatusNormalClosure, "") assert.Success(t, err) } + +func TestWasmDialTimeout(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond) + defer cancel() + + beforeDial := time.Now() + _, _, err := websocket.Dial(ctx, "ws://example.com:9893", &websocket.DialOptions{ + Subprotocols: []string{"echo"}, + }) + assert.Error(t, err) + if time.Since(beforeDial) >= time.Second { + t.Fatal("wasm context dial timeout is not working", time.Since(beforeDial)) + } +} diff --git a/wsjson/wsjson.go b/wsjson/wsjson.go index 2000a77a..7c986a0d 100644 --- a/wsjson/wsjson.go +++ b/wsjson/wsjson.go @@ -9,6 +9,7 @@ import ( "nhooyr.io/websocket" "nhooyr.io/websocket/internal/bpool" "nhooyr.io/websocket/internal/errd" + "nhooyr.io/websocket/internal/util" ) // Read reads a JSON message from c into v. @@ -51,17 +52,17 @@ func Write(ctx context.Context, c *websocket.Conn, v interface{}) error { func write(ctx context.Context, c *websocket.Conn, v interface{}) (err error) { defer errd.Wrap(&err, "failed to write JSON message") - w, err := c.Writer(ctx, websocket.MessageText) - if err != nil { - return err - } - // json.Marshal cannot reuse buffers between calls as it has to return // a copy of the byte slice but Encoder does as it directly writes to w. - err = json.NewEncoder(w).Encode(v) + err = json.NewEncoder(util.WriterFunc(func(p []byte) (int, error) { + err := c.Write(ctx, websocket.MessageText, p) + if err != nil { + return 0, err + } + return len(p), nil + })).Encode(v) if err != nil { return fmt.Errorf("failed to marshal JSON: %w", err) } - - return w.Close() + return nil } diff --git a/wspb/wspb.go b/wspb/wspb.go deleted file mode 100644 index e43042d5..00000000 --- a/wspb/wspb.go +++ /dev/null @@ -1,73 +0,0 @@ -// Package wspb provides helpers for reading and writing protobuf messages. -package wspb // import "nhooyr.io/websocket/wspb" - -import ( - "bytes" - "context" - "fmt" - - "github.com/golang/protobuf/proto" - - "nhooyr.io/websocket" - "nhooyr.io/websocket/internal/bpool" - "nhooyr.io/websocket/internal/errd" -) - -// Read reads a protobuf message from c into v. -// It will reuse buffers in between calls to avoid allocations. -func Read(ctx context.Context, c *websocket.Conn, v proto.Message) error { - return read(ctx, c, v) -} - -func read(ctx context.Context, c *websocket.Conn, v proto.Message) (err error) { - defer errd.Wrap(&err, "failed to read protobuf message") - - typ, r, err := c.Reader(ctx) - if err != nil { - return err - } - - if typ != websocket.MessageBinary { - c.Close(websocket.StatusUnsupportedData, "expected binary message") - return fmt.Errorf("expected binary message for protobuf but got: %v", typ) - } - - b := bpool.Get() - defer bpool.Put(b) - - _, err = b.ReadFrom(r) - if err != nil { - return err - } - - err = proto.Unmarshal(b.Bytes(), v) - if err != nil { - c.Close(websocket.StatusInvalidFramePayloadData, "failed to unmarshal protobuf") - return fmt.Errorf("failed to unmarshal protobuf: %w", err) - } - - return nil -} - -// Write writes the protobuf message v to c. -// It will reuse buffers in between calls to avoid allocations. -func Write(ctx context.Context, c *websocket.Conn, v proto.Message) error { - return write(ctx, c, v) -} - -func write(ctx context.Context, c *websocket.Conn, v proto.Message) (err error) { - defer errd.Wrap(&err, "failed to write protobuf message") - - b := bpool.Get() - pb := proto.NewBuffer(b.Bytes()) - defer func() { - bpool.Put(bytes.NewBuffer(pb.Bytes())) - }() - - err = pb.Marshal(v) - if err != nil { - return fmt.Errorf("failed to marshal protobuf: %w", err) - } - - return c.Write(ctx, websocket.MessageBinary, pb.Bytes()) -}