diff --git a/close.go b/close.go index 0abc864f..1053751c 100644 --- a/close.go +++ b/close.go @@ -99,12 +99,14 @@ func CloseStatus(err error) StatusCode { // Close will unblock all goroutines interacting with the connection once // complete. func (c *Conn) Close(code StatusCode, reason string) error { + defer c.wgWait() return c.closeHandshake(code, reason) } // CloseNow closes the WebSocket connection without attempting a close handshake. -// Use When you do not want the overhead of the close handshake. +// Use when you do not want the overhead of the close handshake. func (c *Conn) CloseNow() (err error) { + defer c.wgWait() defer errd.Wrap(&err, "failed to close WebSocket") if c.isClosed() { diff --git a/conn.go b/conn.go index 5084dce1..05531c3b 100644 --- a/conn.go +++ b/conn.go @@ -45,6 +45,8 @@ const ( type Conn struct { noCopy + wg sync.WaitGroup + subprotocol string rwc io.ReadWriteCloser client bool @@ -53,10 +55,8 @@ type Conn struct { br *bufio.Reader bw *bufio.Writer - timeoutLoopCancel context.CancelFunc - timeoutLoopDone chan struct{} - readTimeout chan context.Context - writeTimeout chan context.Context + readTimeout chan context.Context + writeTimeout chan context.Context // Read state. readMu *mu @@ -104,9 +104,8 @@ func newConn(cfg connConfig) *Conn { br: cfg.br, bw: cfg.bw, - timeoutLoopDone: make(chan struct{}), - readTimeout: make(chan context.Context), - writeTimeout: make(chan context.Context), + readTimeout: make(chan context.Context), + writeTimeout: make(chan context.Context), closed: make(chan struct{}), activePings: make(map[string]chan<- struct{}), @@ -133,9 +132,7 @@ func newConn(cfg connConfig) *Conn { c.close(errors.New("connection garbage collected")) }) - var ctx context.Context - ctx, c.timeoutLoopCancel = context.WithCancel(context.Background()) - go c.timeoutLoop(ctx) + c.wgGo(c.timeoutLoop) return c } @@ -158,9 +155,6 @@ func (c *Conn) close(err error) { } c.setCloseErrLocked(err) - c.timeoutLoopCancel() - <-c.timeoutLoopDone - close(c.closed) runtime.SetFinalizer(c, nil) @@ -169,23 +163,18 @@ func (c *Conn) close(err error) { // closeErr. c.rwc.Close() - c.closeMu.Unlock() - defer c.closeMu.Lock() - - c.msgWriter.close() - c.msgReader.close() + c.wgGo(func() { + c.msgWriter.close() + c.msgReader.close() + }) } -func (c *Conn) timeoutLoop(ctx context.Context) { - defer close(c.timeoutLoopDone) - +func (c *Conn) timeoutLoop() { readCtx := context.Background() writeCtx := context.Background() for { select { - case <-ctx.Done(): - return case <-c.closed: return @@ -194,7 +183,9 @@ func (c *Conn) timeoutLoop(ctx context.Context) { case <-readCtx.Done(): c.setCloseErr(fmt.Errorf("read timed out: %w", readCtx.Err())) - go c.writeError(StatusPolicyViolation, errors.New("timed out")) + c.wgGo(func() { + c.writeError(StatusPolicyViolation, errors.New("read timed out")) + }) case <-writeCtx.Done(): c.close(fmt.Errorf("write timed out: %w", writeCtx.Err())) return @@ -311,3 +302,15 @@ func (m *mu) unlock() { type noCopy struct{} func (*noCopy) Lock() {} + +func (c *Conn) wgGo(fn func()) { + c.wg.Add(1) + go func() { + defer c.wg.Done() + fn() + }() +} + +func (c *Conn) wgWait() { + c.wg.Wait() +} diff --git a/read.go b/read.go index 5c180fba..8742842e 100644 --- a/read.go +++ b/read.go @@ -62,8 +62,11 @@ func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) { // frames are responded to. This means c.Ping and c.Close will still work as expected. func (c *Conn) CloseRead(ctx context.Context) context.Context { ctx, cancel := context.WithCancel(ctx) + + c.wg.Add(1) go func() { defer c.CloseNow() + defer c.wg.Done() defer cancel() _, _, err := c.Reader(ctx) if err == nil { @@ -219,7 +222,6 @@ func (c *Conn) readFrameHeader(ctx context.Context) (header, error) { case <-ctx.Done(): return header{}, ctx.Err() default: - c.readMu.unlock() c.close(err) return header{}, err } @@ -250,7 +252,6 @@ func (c *Conn) readFramePayload(ctx context.Context, p []byte) (int, error) { return n, ctx.Err() default: err = fmt.Errorf("failed to read frame payload: %w", err) - c.readMu.unlock() c.close(err) return n, err } @@ -321,7 +322,6 @@ func (c *Conn) handleControl(ctx context.Context, h header) (err error) { err = fmt.Errorf("received close frame: %w", ce) c.setCloseErr(err) c.writeClose(ce.Code, ce.Reason) - c.readMu.unlock() c.close(err) return err } @@ -337,7 +337,6 @@ func (c *Conn) reader(ctx context.Context) (_ MessageType, _ io.Reader, err erro if !c.msgReader.fin { err = errors.New("previous message not read to completion") - c.readMu.unlock() c.close(fmt.Errorf("failed to get reader: %w", err)) return 0, nil, err } @@ -413,7 +412,6 @@ func (mr *msgReader) Read(p []byte) (n int, err error) { } if err != nil { err = fmt.Errorf("failed to read: %w", err) - mr.c.readMu.unlock() mr.c.close(err) } return n, err diff --git a/write.go b/write.go index 0fbfd9cd..7b1152ce 100644 --- a/write.go +++ b/write.go @@ -109,7 +109,7 @@ func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error if !c.flate() { defer c.msgWriter.mu.unlock() - return c.writeFrame(true, ctx, true, false, c.msgWriter.opcode, p) + return c.writeFrame(ctx, true, false, c.msgWriter.opcode, p) } n, err := mw.Write(p) @@ -146,20 +146,19 @@ func (mw *msgWriter) putFlateWriter() { // Write writes the given bytes to the WebSocket connection. func (mw *msgWriter) Write(p []byte) (_ int, err error) { - if mw.closed { - return 0, errors.New("cannot use closed writer") - } - err = mw.writeMu.lock(mw.ctx) if err != nil { return 0, fmt.Errorf("failed to write: %w", err) } defer mw.writeMu.unlock() + if mw.closed { + return 0, errors.New("cannot use closed writer") + } + defer func() { if err != nil { err = fmt.Errorf("failed to write: %w", err) - mw.writeMu.unlock() mw.c.close(err) } }() @@ -180,7 +179,7 @@ func (mw *msgWriter) Write(p []byte) (_ int, err error) { } func (mw *msgWriter) write(p []byte) (int, error) { - n, err := mw.c.writeFrame(true, mw.ctx, false, mw.flate, mw.opcode, p) + n, err := mw.c.writeFrame(mw.ctx, false, mw.flate, mw.opcode, p) if err != nil { return n, fmt.Errorf("failed to write data frame: %w", err) } @@ -210,7 +209,7 @@ func (mw *msgWriter) Close() (err error) { } } - _, err = mw.c.writeFrame(true, mw.ctx, true, mw.flate, mw.opcode, nil) + _, err = mw.c.writeFrame(mw.ctx, true, mw.flate, mw.opcode, nil) if err != nil { return fmt.Errorf("failed to write fin frame: %w", err) } @@ -236,7 +235,7 @@ func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error ctx, cancel := context.WithTimeout(ctx, time.Second*5) defer cancel() - _, err := c.writeFrame(false, ctx, true, false, opcode, p) + _, err := c.writeFrame(ctx, true, false, opcode, p) if err != nil { return fmt.Errorf("failed to write control frame %v: %w", opcode, err) } @@ -244,7 +243,7 @@ func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error } // frame handles all writes to the connection. -func (c *Conn) writeFrame(msgWriter bool, ctx context.Context, fin bool, flate bool, opcode opcode, p []byte) (_ int, err error) { +func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opcode, p []byte) (_ int, err error) { err = c.writeFrameMu.lock(ctx) if err != nil { return 0, err @@ -284,10 +283,6 @@ func (c *Conn) writeFrame(msgWriter bool, ctx context.Context, fin bool, flate b err = ctx.Err() default: } - c.writeFrameMu.unlock() - if msgWriter { - c.msgWriter.writeMu.unlock() - } c.close(err) err = fmt.Errorf("failed to write frame: %w", err) }