Skip to content

Commit

Permalink
Ensure no goroutines leak after Close in a cleaner way
Browse files Browse the repository at this point in the history
Closes #330
  • Loading branch information
nhooyr committed Oct 19, 2023
1 parent 6ed989a commit d7a55cf
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 44 deletions.
4 changes: 3 additions & 1 deletion close.go
Expand Up @@ -99,12 +99,14 @@ func CloseStatus(err error) StatusCode {
// Close will unblock all goroutines interacting with the connection once
// complete.
func (c *Conn) Close(code StatusCode, reason string) error {
defer c.wgWait()
return c.closeHandshake(code, reason)
}

// CloseNow closes the WebSocket connection without attempting a close handshake.
// Use When you do not want the overhead of the close handshake.
// Use when you do not want the overhead of the close handshake.
func (c *Conn) CloseNow() (err error) {
defer c.wgWait()
defer errd.Wrap(&err, "failed to close WebSocket")

if c.isClosed() {
Expand Down
51 changes: 27 additions & 24 deletions conn.go
Expand Up @@ -45,6 +45,8 @@ const (
type Conn struct {
noCopy

wg sync.WaitGroup

subprotocol string
rwc io.ReadWriteCloser
client bool
Expand All @@ -53,10 +55,8 @@ type Conn struct {
br *bufio.Reader
bw *bufio.Writer

timeoutLoopCancel context.CancelFunc
timeoutLoopDone chan struct{}
readTimeout chan context.Context
writeTimeout chan context.Context
readTimeout chan context.Context
writeTimeout chan context.Context

// Read state.
readMu *mu
Expand Down Expand Up @@ -104,9 +104,8 @@ func newConn(cfg connConfig) *Conn {
br: cfg.br,
bw: cfg.bw,

timeoutLoopDone: make(chan struct{}),
readTimeout: make(chan context.Context),
writeTimeout: make(chan context.Context),
readTimeout: make(chan context.Context),
writeTimeout: make(chan context.Context),

closed: make(chan struct{}),
activePings: make(map[string]chan<- struct{}),
Expand All @@ -133,9 +132,7 @@ func newConn(cfg connConfig) *Conn {
c.close(errors.New("connection garbage collected"))
})

var ctx context.Context
ctx, c.timeoutLoopCancel = context.WithCancel(context.Background())
go c.timeoutLoop(ctx)
c.wgGo(c.timeoutLoop)

return c
}
Expand All @@ -158,9 +155,6 @@ func (c *Conn) close(err error) {
}
c.setCloseErrLocked(err)

c.timeoutLoopCancel()
<-c.timeoutLoopDone

close(c.closed)
runtime.SetFinalizer(c, nil)

Expand All @@ -169,23 +163,18 @@ func (c *Conn) close(err error) {
// closeErr.
c.rwc.Close()

c.closeMu.Unlock()
defer c.closeMu.Lock()

c.msgWriter.close()
c.msgReader.close()
c.wgGo(func() {
c.msgWriter.close()
c.msgReader.close()
})
}

func (c *Conn) timeoutLoop(ctx context.Context) {
defer close(c.timeoutLoopDone)

func (c *Conn) timeoutLoop() {
readCtx := context.Background()
writeCtx := context.Background()

for {
select {
case <-ctx.Done():
return
case <-c.closed:
return

Expand All @@ -194,7 +183,9 @@ func (c *Conn) timeoutLoop(ctx context.Context) {

case <-readCtx.Done():
c.setCloseErr(fmt.Errorf("read timed out: %w", readCtx.Err()))
go c.writeError(StatusPolicyViolation, errors.New("timed out"))
c.wgGo(func() {
c.writeError(StatusPolicyViolation, errors.New("read timed out"))
})
case <-writeCtx.Done():
c.close(fmt.Errorf("write timed out: %w", writeCtx.Err()))
return
Expand Down Expand Up @@ -311,3 +302,15 @@ func (m *mu) unlock() {
type noCopy struct{}

func (*noCopy) Lock() {}

func (c *Conn) wgGo(fn func()) {
c.wg.Add(1)
go func() {
defer c.wg.Done()
fn()
}()
}

func (c *Conn) wgWait() {
c.wg.Wait()
}
8 changes: 3 additions & 5 deletions read.go
Expand Up @@ -62,8 +62,11 @@ func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) {
// frames are responded to. This means c.Ping and c.Close will still work as expected.
func (c *Conn) CloseRead(ctx context.Context) context.Context {
ctx, cancel := context.WithCancel(ctx)

c.wg.Add(1)
go func() {
defer c.CloseNow()
defer c.wg.Done()
defer cancel()
_, _, err := c.Reader(ctx)
if err == nil {
Expand Down Expand Up @@ -219,7 +222,6 @@ func (c *Conn) readFrameHeader(ctx context.Context) (header, error) {
case <-ctx.Done():
return header{}, ctx.Err()
default:
c.readMu.unlock()
c.close(err)
return header{}, err
}
Expand Down Expand Up @@ -250,7 +252,6 @@ func (c *Conn) readFramePayload(ctx context.Context, p []byte) (int, error) {
return n, ctx.Err()
default:
err = fmt.Errorf("failed to read frame payload: %w", err)
c.readMu.unlock()
c.close(err)
return n, err
}
Expand Down Expand Up @@ -321,7 +322,6 @@ func (c *Conn) handleControl(ctx context.Context, h header) (err error) {
err = fmt.Errorf("received close frame: %w", ce)
c.setCloseErr(err)
c.writeClose(ce.Code, ce.Reason)
c.readMu.unlock()
c.close(err)
return err
}
Expand All @@ -337,7 +337,6 @@ func (c *Conn) reader(ctx context.Context) (_ MessageType, _ io.Reader, err erro

if !c.msgReader.fin {
err = errors.New("previous message not read to completion")
c.readMu.unlock()
c.close(fmt.Errorf("failed to get reader: %w", err))
return 0, nil, err
}
Expand Down Expand Up @@ -413,7 +412,6 @@ func (mr *msgReader) Read(p []byte) (n int, err error) {
}
if err != nil {
err = fmt.Errorf("failed to read: %w", err)
mr.c.readMu.unlock()
mr.c.close(err)
}
return n, err
Expand Down
23 changes: 9 additions & 14 deletions write.go
Expand Up @@ -109,7 +109,7 @@ func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error

if !c.flate() {
defer c.msgWriter.mu.unlock()
return c.writeFrame(true, ctx, true, false, c.msgWriter.opcode, p)
return c.writeFrame(ctx, true, false, c.msgWriter.opcode, p)
}

n, err := mw.Write(p)
Expand Down Expand Up @@ -146,20 +146,19 @@ func (mw *msgWriter) putFlateWriter() {

// Write writes the given bytes to the WebSocket connection.
func (mw *msgWriter) Write(p []byte) (_ int, err error) {
if mw.closed {
return 0, errors.New("cannot use closed writer")
}

err = mw.writeMu.lock(mw.ctx)
if err != nil {
return 0, fmt.Errorf("failed to write: %w", err)
}
defer mw.writeMu.unlock()

if mw.closed {
return 0, errors.New("cannot use closed writer")
}

defer func() {
if err != nil {
err = fmt.Errorf("failed to write: %w", err)
mw.writeMu.unlock()
mw.c.close(err)
}
}()
Expand All @@ -180,7 +179,7 @@ func (mw *msgWriter) Write(p []byte) (_ int, err error) {
}

func (mw *msgWriter) write(p []byte) (int, error) {
n, err := mw.c.writeFrame(true, mw.ctx, false, mw.flate, mw.opcode, p)
n, err := mw.c.writeFrame(mw.ctx, false, mw.flate, mw.opcode, p)
if err != nil {
return n, fmt.Errorf("failed to write data frame: %w", err)
}
Expand Down Expand Up @@ -210,7 +209,7 @@ func (mw *msgWriter) Close() (err error) {
}
}

_, err = mw.c.writeFrame(true, mw.ctx, true, mw.flate, mw.opcode, nil)
_, err = mw.c.writeFrame(mw.ctx, true, mw.flate, mw.opcode, nil)
if err != nil {
return fmt.Errorf("failed to write fin frame: %w", err)
}
Expand All @@ -236,15 +235,15 @@ func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error
ctx, cancel := context.WithTimeout(ctx, time.Second*5)
defer cancel()

_, err := c.writeFrame(false, ctx, true, false, opcode, p)
_, err := c.writeFrame(ctx, true, false, opcode, p)
if err != nil {
return fmt.Errorf("failed to write control frame %v: %w", opcode, err)
}
return nil
}

// frame handles all writes to the connection.
func (c *Conn) writeFrame(msgWriter bool, ctx context.Context, fin bool, flate bool, opcode opcode, p []byte) (_ int, err error) {
func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opcode, p []byte) (_ int, err error) {
err = c.writeFrameMu.lock(ctx)
if err != nil {
return 0, err
Expand Down Expand Up @@ -284,10 +283,6 @@ func (c *Conn) writeFrame(msgWriter bool, ctx context.Context, fin bool, flate b
err = ctx.Err()
default:
}
c.writeFrameMu.unlock()
if msgWriter {
c.msgWriter.writeMu.unlock()
}
c.close(err)
err = fmt.Errorf("failed to write frame: %w", err)
}
Expand Down

0 comments on commit d7a55cf

Please sign in to comment.