Skip to content

Commit

Permalink
write: Zero alloc writes with Writer
Browse files Browse the repository at this point in the history
Closes #354
  • Loading branch information
nhooyr committed Oct 14, 2023
1 parent a975390 commit 1dbc141
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 45 deletions.
1 change: 0 additions & 1 deletion .gitignore

This file was deleted.

4 changes: 2 additions & 2 deletions ci/bench.sh
Expand Up @@ -2,8 +2,8 @@
set -eu
cd -- "$(dirname "$0")/.."

go test --run=^$ --bench=. "$@" ./...
go test --run=^$ --bench=. --benchmem --memprofile ci/out/prof.mem --cpuprofile ci/out/prof.cpu -o ci/out/websocket.test "$@" .
(
cd ./internal/thirdparty
go test --run=^$ --bench=. "$@" ./...
go test --run=^$ --bench=. --benchmem --memprofile ../../ci/out/prof-thirdparty.mem --cpuprofile ../../ci/out/prof-thirdparty.cpu -o ../../ci/out/thirdparty.test "$@" .
)
9 changes: 4 additions & 5 deletions conn.go
Expand Up @@ -63,7 +63,7 @@ type Conn struct {
readCloseFrameErr error

// Write state.
msgWriterState *msgWriterState
msgWriter *msgWriter
writeFrameMu *mu
writeBuf []byte
writeHeaderBuf [8]byte
Expand Down Expand Up @@ -113,14 +113,14 @@ func newConn(cfg connConfig) *Conn {

c.msgReader = newMsgReader(c)

c.msgWriterState = newMsgWriterState(c)
c.msgWriter = newMsgWriter(c)
if c.client {
c.writeBuf = extractBufioWriterBuf(c.bw, c.rwc)
}

if c.flate() && c.flateThreshold == 0 {
c.flateThreshold = 128
if !c.msgWriterState.flateContextTakeover() {
if !c.msgWriter.flateContextTakeover() {
c.flateThreshold = 512
}
}
Expand Down Expand Up @@ -157,8 +157,7 @@ func (c *Conn) close(err error) {
c.rwc.Close()

go func() {
c.msgWriterState.close()

c.msgWriter.close()
c.msgReader.close()
}()
}
Expand Down
62 changes: 25 additions & 37 deletions write.go
Expand Up @@ -49,30 +49,11 @@ func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error {
}

type msgWriter struct {
mw *msgWriterState
closed bool
}

func (mw *msgWriter) Write(p []byte) (int, error) {
if mw.closed {
return 0, errors.New("cannot use closed writer")
}
return mw.mw.Write(p)
}

func (mw *msgWriter) Close() error {
if mw.closed {
return errors.New("cannot use closed writer")
}
mw.closed = true
return mw.mw.Close()
}

type msgWriterState struct {
c *Conn

mu *mu
writeMu *mu
closed bool

ctx context.Context
opcode opcode
Expand All @@ -82,16 +63,16 @@ type msgWriterState struct {
flateWriter *flate.Writer
}

func newMsgWriterState(c *Conn) *msgWriterState {
mw := &msgWriterState{
func newMsgWriter(c *Conn) *msgWriter {
mw := &msgWriter{
c: c,
mu: newMu(c),
writeMu: newMu(c),
}
return mw
}

func (mw *msgWriterState) ensureFlate() {
func (mw *msgWriter) ensureFlate() {
if mw.trimWriter == nil {
mw.trimWriter = &trimLastFourBytesWriter{
w: util.WriterFunc(mw.write),
Expand All @@ -104,22 +85,19 @@ func (mw *msgWriterState) ensureFlate() {
mw.flate = true
}

func (mw *msgWriterState) flateContextTakeover() bool {
func (mw *msgWriter) flateContextTakeover() bool {
if mw.c.client {
return !mw.c.copts.clientNoContextTakeover
}
return !mw.c.copts.serverNoContextTakeover
}

func (c *Conn) writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) {
err := c.msgWriterState.reset(ctx, typ)
err := c.msgWriter.reset(ctx, typ)
if err != nil {
return nil, err
}
return &msgWriter{
mw: c.msgWriterState,
closed: false,
}, nil
return c.msgWriter, nil
}

func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error) {
Expand All @@ -129,8 +107,8 @@ func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error
}

if !c.flate() {
defer c.msgWriterState.mu.unlock()
return c.writeFrame(ctx, true, false, c.msgWriterState.opcode, p)
defer c.msgWriter.mu.unlock()
return c.writeFrame(ctx, true, false, c.msgWriter.opcode, p)
}

n, err := mw.Write(p)
Expand All @@ -142,7 +120,7 @@ func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error
return n, err
}

func (mw *msgWriterState) reset(ctx context.Context, typ MessageType) error {
func (mw *msgWriter) reset(ctx context.Context, typ MessageType) error {
err := mw.mu.lock(ctx)
if err != nil {
return err
Expand All @@ -151,21 +129,26 @@ func (mw *msgWriterState) reset(ctx context.Context, typ MessageType) error {
mw.ctx = ctx
mw.opcode = opcode(typ)
mw.flate = false
mw.closed = false

mw.trimWriter.reset()

return nil
}

func (mw *msgWriterState) putFlateWriter() {
func (mw *msgWriter) putFlateWriter() {
if mw.flateWriter != nil {
putFlateWriter(mw.flateWriter)
mw.flateWriter = nil
}
}

// Write writes the given bytes to the WebSocket connection.
func (mw *msgWriterState) Write(p []byte) (_ int, err error) {
func (mw *msgWriter) Write(p []byte) (_ int, err error) {
if mw.closed {
return 0, errors.New("cannot use closed writer")
}

err = mw.writeMu.lock(mw.ctx)
if err != nil {
return 0, fmt.Errorf("failed to write: %w", err)
Expand Down Expand Up @@ -194,7 +177,7 @@ func (mw *msgWriterState) Write(p []byte) (_ int, err error) {
return mw.write(p)
}

func (mw *msgWriterState) write(p []byte) (int, error) {
func (mw *msgWriter) write(p []byte) (int, error) {
n, err := mw.c.writeFrame(mw.ctx, false, mw.flate, mw.opcode, p)
if err != nil {
return n, fmt.Errorf("failed to write data frame: %w", err)
Expand All @@ -204,9 +187,14 @@ func (mw *msgWriterState) write(p []byte) (int, error) {
}

// Close flushes the frame to the connection.
func (mw *msgWriterState) Close() (err error) {
func (mw *msgWriter) Close() (err error) {
defer errd.Wrap(&err, "failed to close writer")

if mw.closed {
return errors.New("writer already closed")
}
mw.closed = true

err = mw.writeMu.lock(mw.ctx)
if err != nil {
return err
Expand All @@ -232,7 +220,7 @@ func (mw *msgWriterState) Close() (err error) {
return nil
}

func (mw *msgWriterState) close() {
func (mw *msgWriter) close() {
if mw.c.client {
mw.c.writeFrameMu.forceLock()
putBufioWriter(mw.c.bw)
Expand Down

0 comments on commit 1dbc141

Please sign in to comment.