diff --git a/.gitignore b/.gitignore deleted file mode 100644 index 6961e5c8..00000000 --- a/.gitignore +++ /dev/null @@ -1 +0,0 @@ -websocket.test diff --git a/ci/bench.sh b/ci/bench.sh index 8f99278d..a553b93a 100755 --- a/ci/bench.sh +++ b/ci/bench.sh @@ -2,8 +2,8 @@ set -eu cd -- "$(dirname "$0")/.." -go test --run=^$ --bench=. "$@" ./... +go test --run=^$ --bench=. --benchmem --memprofile ci/out/prof.mem --cpuprofile ci/out/prof.cpu -o ci/out/websocket.test "$@" . ( cd ./internal/thirdparty - go test --run=^$ --bench=. "$@" ./... + go test --run=^$ --bench=. --benchmem --memprofile ../../ci/out/prof-thirdparty.mem --cpuprofile ../../ci/out/prof-thirdparty.cpu -o ../../ci/out/thirdparty.test "$@" . ) diff --git a/conn.go b/conn.go index 81a57c7f..78eaad82 100644 --- a/conn.go +++ b/conn.go @@ -63,7 +63,7 @@ type Conn struct { readCloseFrameErr error // Write state. - msgWriterState *msgWriterState + msgWriter *msgWriter writeFrameMu *mu writeBuf []byte writeHeaderBuf [8]byte @@ -113,14 +113,14 @@ func newConn(cfg connConfig) *Conn { c.msgReader = newMsgReader(c) - c.msgWriterState = newMsgWriterState(c) + c.msgWriter = newMsgWriter(c) if c.client { c.writeBuf = extractBufioWriterBuf(c.bw, c.rwc) } if c.flate() && c.flateThreshold == 0 { c.flateThreshold = 128 - if !c.msgWriterState.flateContextTakeover() { + if !c.msgWriter.flateContextTakeover() { c.flateThreshold = 512 } } @@ -157,8 +157,7 @@ func (c *Conn) close(err error) { c.rwc.Close() go func() { - c.msgWriterState.close() - + c.msgWriter.close() c.msgReader.close() }() } diff --git a/write.go b/write.go index 500609dd..20a71d3e 100644 --- a/write.go +++ b/write.go @@ -49,30 +49,11 @@ func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error { } type msgWriter struct { - mw *msgWriterState - closed bool -} - -func (mw *msgWriter) Write(p []byte) (int, error) { - if mw.closed { - return 0, errors.New("cannot use closed writer") - } - return mw.mw.Write(p) -} - -func (mw *msgWriter) Close() error { - if mw.closed { - return errors.New("cannot use closed writer") - } - mw.closed = true - return mw.mw.Close() -} - -type msgWriterState struct { c *Conn mu *mu writeMu *mu + closed bool ctx context.Context opcode opcode @@ -82,8 +63,8 @@ type msgWriterState struct { flateWriter *flate.Writer } -func newMsgWriterState(c *Conn) *msgWriterState { - mw := &msgWriterState{ +func newMsgWriter(c *Conn) *msgWriter { + mw := &msgWriter{ c: c, mu: newMu(c), writeMu: newMu(c), @@ -91,7 +72,7 @@ func newMsgWriterState(c *Conn) *msgWriterState { return mw } -func (mw *msgWriterState) ensureFlate() { +func (mw *msgWriter) ensureFlate() { if mw.trimWriter == nil { mw.trimWriter = &trimLastFourBytesWriter{ w: util.WriterFunc(mw.write), @@ -104,7 +85,7 @@ func (mw *msgWriterState) ensureFlate() { mw.flate = true } -func (mw *msgWriterState) flateContextTakeover() bool { +func (mw *msgWriter) flateContextTakeover() bool { if mw.c.client { return !mw.c.copts.clientNoContextTakeover } @@ -112,14 +93,11 @@ func (mw *msgWriterState) flateContextTakeover() bool { } func (c *Conn) writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) { - err := c.msgWriterState.reset(ctx, typ) + err := c.msgWriter.reset(ctx, typ) if err != nil { return nil, err } - return &msgWriter{ - mw: c.msgWriterState, - closed: false, - }, nil + return c.msgWriter, nil } func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error) { @@ -129,8 +107,8 @@ func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error } if !c.flate() { - defer c.msgWriterState.mu.unlock() - return c.writeFrame(ctx, true, false, c.msgWriterState.opcode, p) + defer c.msgWriter.mu.unlock() + return c.writeFrame(ctx, true, false, c.msgWriter.opcode, p) } n, err := mw.Write(p) @@ -142,7 +120,7 @@ func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error return n, err } -func (mw *msgWriterState) reset(ctx context.Context, typ MessageType) error { +func (mw *msgWriter) reset(ctx context.Context, typ MessageType) error { err := mw.mu.lock(ctx) if err != nil { return err @@ -151,13 +129,14 @@ func (mw *msgWriterState) reset(ctx context.Context, typ MessageType) error { mw.ctx = ctx mw.opcode = opcode(typ) mw.flate = false + mw.closed = false mw.trimWriter.reset() return nil } -func (mw *msgWriterState) putFlateWriter() { +func (mw *msgWriter) putFlateWriter() { if mw.flateWriter != nil { putFlateWriter(mw.flateWriter) mw.flateWriter = nil @@ -165,7 +144,11 @@ func (mw *msgWriterState) putFlateWriter() { } // Write writes the given bytes to the WebSocket connection. -func (mw *msgWriterState) Write(p []byte) (_ int, err error) { +func (mw *msgWriter) Write(p []byte) (_ int, err error) { + if mw.closed { + return 0, errors.New("cannot use closed writer") + } + err = mw.writeMu.lock(mw.ctx) if err != nil { return 0, fmt.Errorf("failed to write: %w", err) @@ -194,7 +177,7 @@ func (mw *msgWriterState) Write(p []byte) (_ int, err error) { return mw.write(p) } -func (mw *msgWriterState) write(p []byte) (int, error) { +func (mw *msgWriter) write(p []byte) (int, error) { n, err := mw.c.writeFrame(mw.ctx, false, mw.flate, mw.opcode, p) if err != nil { return n, fmt.Errorf("failed to write data frame: %w", err) @@ -204,9 +187,14 @@ func (mw *msgWriterState) write(p []byte) (int, error) { } // Close flushes the frame to the connection. -func (mw *msgWriterState) Close() (err error) { +func (mw *msgWriter) Close() (err error) { defer errd.Wrap(&err, "failed to close writer") + if mw.closed { + return errors.New("writer already closed") + } + mw.closed = true + err = mw.writeMu.lock(mw.ctx) if err != nil { return err @@ -232,7 +220,7 @@ func (mw *msgWriterState) Close() (err error) { return nil } -func (mw *msgWriterState) close() { +func (mw *msgWriter) close() { if mw.c.client { mw.c.writeFrameMu.forceLock() putBufioWriter(mw.c.bw)