Skip to content

Commit

Permalink
Ensure no goroutines leak after Close
Browse files Browse the repository at this point in the history
Closes #330
  • Loading branch information
nhooyr committed Oct 19, 2023
1 parent e6a7e0e commit 6ed989a
Show file tree
Hide file tree
Showing 7 changed files with 70 additions and 31 deletions.
34 changes: 24 additions & 10 deletions conn.go
Expand Up @@ -53,8 +53,10 @@ type Conn struct {
br *bufio.Reader
bw *bufio.Writer

readTimeout chan context.Context
writeTimeout chan context.Context
timeoutLoopCancel context.CancelFunc
timeoutLoopDone chan struct{}
readTimeout chan context.Context
writeTimeout chan context.Context

// Read state.
readMu *mu
Expand Down Expand Up @@ -102,8 +104,9 @@ func newConn(cfg connConfig) *Conn {
br: cfg.br,
bw: cfg.bw,

readTimeout: make(chan context.Context),
writeTimeout: make(chan context.Context),
timeoutLoopDone: make(chan struct{}),
readTimeout: make(chan context.Context),
writeTimeout: make(chan context.Context),

closed: make(chan struct{}),
activePings: make(map[string]chan<- struct{}),
Expand All @@ -130,7 +133,9 @@ func newConn(cfg connConfig) *Conn {
c.close(errors.New("connection garbage collected"))
})

go c.timeoutLoop()
var ctx context.Context
ctx, c.timeoutLoopCancel = context.WithCancel(context.Background())
go c.timeoutLoop(ctx)

return c
}
Expand All @@ -152,6 +157,10 @@ func (c *Conn) close(err error) {
err = c.rwc.Close()
}
c.setCloseErrLocked(err)

c.timeoutLoopCancel()
<-c.timeoutLoopDone

close(c.closed)
runtime.SetFinalizer(c, nil)

Expand All @@ -160,18 +169,23 @@ func (c *Conn) close(err error) {
// closeErr.
c.rwc.Close()

go func() {
c.msgWriter.close()
c.msgReader.close()
}()
c.closeMu.Unlock()
defer c.closeMu.Lock()

c.msgWriter.close()
c.msgReader.close()
}

func (c *Conn) timeoutLoop() {
func (c *Conn) timeoutLoop(ctx context.Context) {
defer close(c.timeoutLoopDone)

readCtx := context.Background()
writeCtx := context.Background()

for {
select {
case <-ctx.Done():
return
case <-c.closed:
return

Expand Down
17 changes: 9 additions & 8 deletions conn_test.go
Expand Up @@ -399,10 +399,8 @@ func newConnTest(t testing.TB, dialOpts *websocket.DialOptions, acceptOpts *webs
c1, c2 = c2, c1
}
t.Cleanup(func() {
// We don't actually care whether this succeeds so we just run it in a separate goroutine to avoid
// blocking the test shutting down.
go c2.Close(websocket.StatusInternalError, "")
go c1.Close(websocket.StatusInternalError, "")
c2.CloseNow()
c1.CloseNow()
})

return tt, c1, c2
Expand Down Expand Up @@ -596,16 +594,19 @@ func TestConcurrentClosePing(t *testing.T) {
defer c2.CloseNow()
c1.CloseRead(context.Background())
c2.CloseRead(context.Background())
go func() {
errc := xsync.Go(func() error {
for range time.Tick(time.Millisecond) {
if err := c1.Ping(context.Background()); err != nil {
return
err := c1.Ping(context.Background())
if err != nil {
return err
}
}
}()
panic("unreachable")
})

time.Sleep(10 * time.Millisecond)
assert.Success(t, c1.Close(websocket.StatusNormalClosure, ""))
<-errc
}()
}
}
3 changes: 2 additions & 1 deletion dial_test.go
Expand Up @@ -164,11 +164,12 @@ func Test_verifyHostOverride(t *testing.T) {
}, nil
}

_, _, err := websocket.Dial(ctx, "ws://example.com", &websocket.DialOptions{
c, _, err := websocket.Dial(ctx, "ws://example.com", &websocket.DialOptions{
HTTPClient: mockHTTPClient(rt),
Host: tc.host,
})
assert.Success(t, err)
c.CloseNow()
})
}

Expand Down
15 changes: 14 additions & 1 deletion main_test.go
Expand Up @@ -7,10 +7,23 @@ import (
"testing"
)

func goroutineStacks() []byte {
buf := make([]byte, 512)
for {
m := runtime.Stack(buf, true)
if m < len(buf) {
return buf[:m]
}
buf = make([]byte, len(buf)*2)
}
}

func TestMain(m *testing.M) {
code := m.Run()
if runtime.NumGoroutine() != 1 {
if runtime.GOOS != "js" && runtime.NumGoroutine() != 1 ||
runtime.GOOS == "js" && runtime.NumGoroutine() != 2 {
fmt.Fprintf(os.Stderr, "goroutine leak detected, expected 1 but got %d goroutines\n", runtime.NumGoroutine())
fmt.Fprintf(os.Stderr, "%s\n", goroutineStacks())
os.Exit(1)
}
os.Exit(code)
Expand Down
5 changes: 5 additions & 0 deletions read.go
Expand Up @@ -219,6 +219,7 @@ func (c *Conn) readFrameHeader(ctx context.Context) (header, error) {
case <-ctx.Done():
return header{}, ctx.Err()
default:
c.readMu.unlock()
c.close(err)
return header{}, err
}
Expand Down Expand Up @@ -249,6 +250,7 @@ func (c *Conn) readFramePayload(ctx context.Context, p []byte) (int, error) {
return n, ctx.Err()
default:
err = fmt.Errorf("failed to read frame payload: %w", err)
c.readMu.unlock()
c.close(err)
return n, err
}
Expand Down Expand Up @@ -319,6 +321,7 @@ func (c *Conn) handleControl(ctx context.Context, h header) (err error) {
err = fmt.Errorf("received close frame: %w", ce)
c.setCloseErr(err)
c.writeClose(ce.Code, ce.Reason)
c.readMu.unlock()
c.close(err)
return err
}
Expand All @@ -334,6 +337,7 @@ func (c *Conn) reader(ctx context.Context) (_ MessageType, _ io.Reader, err erro

if !c.msgReader.fin {
err = errors.New("previous message not read to completion")
c.readMu.unlock()
c.close(fmt.Errorf("failed to get reader: %w", err))
return 0, nil, err
}
Expand Down Expand Up @@ -409,6 +413,7 @@ func (mr *msgReader) Read(p []byte) (n int, err error) {
}
if err != nil {
err = fmt.Errorf("failed to read: %w", err)
mr.c.readMu.unlock()
mr.c.close(err)
}
return n, err
Expand Down
25 changes: 15 additions & 10 deletions write.go
Expand Up @@ -109,7 +109,7 @@ func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error

if !c.flate() {
defer c.msgWriter.mu.unlock()
return c.writeFrame(ctx, true, false, c.msgWriter.opcode, p)
return c.writeFrame(true, ctx, true, false, c.msgWriter.opcode, p)
}

n, err := mw.Write(p)
Expand Down Expand Up @@ -159,6 +159,7 @@ func (mw *msgWriter) Write(p []byte) (_ int, err error) {
defer func() {
if err != nil {
err = fmt.Errorf("failed to write: %w", err)
mw.writeMu.unlock()
mw.c.close(err)
}
}()
Expand All @@ -179,7 +180,7 @@ func (mw *msgWriter) Write(p []byte) (_ int, err error) {
}

func (mw *msgWriter) write(p []byte) (int, error) {
n, err := mw.c.writeFrame(mw.ctx, false, mw.flate, mw.opcode, p)
n, err := mw.c.writeFrame(true, mw.ctx, false, mw.flate, mw.opcode, p)
if err != nil {
return n, fmt.Errorf("failed to write data frame: %w", err)
}
Expand All @@ -191,25 +192,25 @@ func (mw *msgWriter) write(p []byte) (int, error) {
func (mw *msgWriter) Close() (err error) {
defer errd.Wrap(&err, "failed to close writer")

if mw.closed {
return errors.New("writer already closed")
}
mw.closed = true

err = mw.writeMu.lock(mw.ctx)
if err != nil {
return err
}
defer mw.writeMu.unlock()

if mw.closed {
return errors.New("writer already closed")
}
mw.closed = true

if mw.flate {
err = mw.flateWriter.Flush()
if err != nil {
return fmt.Errorf("failed to flush flate: %w", err)
}
}

_, err = mw.c.writeFrame(mw.ctx, true, mw.flate, mw.opcode, nil)
_, err = mw.c.writeFrame(true, mw.ctx, true, mw.flate, mw.opcode, nil)
if err != nil {
return fmt.Errorf("failed to write fin frame: %w", err)
}
Expand All @@ -235,15 +236,15 @@ func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error
ctx, cancel := context.WithTimeout(ctx, time.Second*5)
defer cancel()

_, err := c.writeFrame(ctx, true, false, opcode, p)
_, err := c.writeFrame(false, ctx, true, false, opcode, p)
if err != nil {
return fmt.Errorf("failed to write control frame %v: %w", opcode, err)
}
return nil
}

// frame handles all writes to the connection.
func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opcode, p []byte) (_ int, err error) {
func (c *Conn) writeFrame(msgWriter bool, ctx context.Context, fin bool, flate bool, opcode opcode, p []byte) (_ int, err error) {

Check failure on line 247 in write.go

View workflow job for this annotation

GitHub Actions / lint

context.Context should be the first parameter of a function

Check failure on line 247 in write.go

View workflow job for this annotation

GitHub Actions / lint

context.Context should be the first parameter of a function
err = c.writeFrameMu.lock(ctx)
if err != nil {
return 0, err
Expand Down Expand Up @@ -283,6 +284,10 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opco
err = ctx.Err()
default:
}
c.writeFrameMu.unlock()
if msgWriter {
c.msgWriter.writeMu.unlock()
}
c.close(err)
err = fmt.Errorf("failed to write frame: %w", err)
}
Expand Down
2 changes: 1 addition & 1 deletion ws_js.go
Expand Up @@ -231,7 +231,7 @@ func (c *Conn) Close(code StatusCode, reason string) error {
}

// CloseNow closes the WebSocket connection without attempting a close handshake.
// Use When you do not want the overhead of the close handshake.
// Use when you do not want the overhead of the close handshake.
//
// note: No different from Close(StatusGoingAway, "") in WASM as there is no way to close
// a WebSocket without the close handshake.
Expand Down

0 comments on commit 6ed989a

Please sign in to comment.