From 6ed989afc10be2cf8139362ca006cad4a1cb98d8 Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Thu, 19 Oct 2023 02:51:36 -0700 Subject: [PATCH] Ensure no goroutines leak after Close Closes #330 --- conn.go | 34 ++++++++++++++++++++++++---------- conn_test.go | 17 +++++++++-------- dial_test.go | 3 ++- main_test.go | 15 ++++++++++++++- read.go | 5 +++++ write.go | 25 +++++++++++++++---------- ws_js.go | 2 +- 7 files changed, 70 insertions(+), 31 deletions(-) diff --git a/conn.go b/conn.go index 3b3a9f98..5084dce1 100644 --- a/conn.go +++ b/conn.go @@ -53,8 +53,10 @@ type Conn struct { br *bufio.Reader bw *bufio.Writer - readTimeout chan context.Context - writeTimeout chan context.Context + timeoutLoopCancel context.CancelFunc + timeoutLoopDone chan struct{} + readTimeout chan context.Context + writeTimeout chan context.Context // Read state. readMu *mu @@ -102,8 +104,9 @@ func newConn(cfg connConfig) *Conn { br: cfg.br, bw: cfg.bw, - readTimeout: make(chan context.Context), - writeTimeout: make(chan context.Context), + timeoutLoopDone: make(chan struct{}), + readTimeout: make(chan context.Context), + writeTimeout: make(chan context.Context), closed: make(chan struct{}), activePings: make(map[string]chan<- struct{}), @@ -130,7 +133,9 @@ func newConn(cfg connConfig) *Conn { c.close(errors.New("connection garbage collected")) }) - go c.timeoutLoop() + var ctx context.Context + ctx, c.timeoutLoopCancel = context.WithCancel(context.Background()) + go c.timeoutLoop(ctx) return c } @@ -152,6 +157,10 @@ func (c *Conn) close(err error) { err = c.rwc.Close() } c.setCloseErrLocked(err) + + c.timeoutLoopCancel() + <-c.timeoutLoopDone + close(c.closed) runtime.SetFinalizer(c, nil) @@ -160,18 +169,23 @@ func (c *Conn) close(err error) { // closeErr. c.rwc.Close() - go func() { - c.msgWriter.close() - c.msgReader.close() - }() + c.closeMu.Unlock() + defer c.closeMu.Lock() + + c.msgWriter.close() + c.msgReader.close() } -func (c *Conn) timeoutLoop() { +func (c *Conn) timeoutLoop(ctx context.Context) { + defer close(c.timeoutLoopDone) + readCtx := context.Background() writeCtx := context.Background() for { select { + case <-ctx.Done(): + return case <-c.closed: return diff --git a/conn_test.go b/conn_test.go index 17c52c32..97b172dc 100644 --- a/conn_test.go +++ b/conn_test.go @@ -399,10 +399,8 @@ func newConnTest(t testing.TB, dialOpts *websocket.DialOptions, acceptOpts *webs c1, c2 = c2, c1 } t.Cleanup(func() { - // We don't actually care whether this succeeds so we just run it in a separate goroutine to avoid - // blocking the test shutting down. - go c2.Close(websocket.StatusInternalError, "") - go c1.Close(websocket.StatusInternalError, "") + c2.CloseNow() + c1.CloseNow() }) return tt, c1, c2 @@ -596,16 +594,19 @@ func TestConcurrentClosePing(t *testing.T) { defer c2.CloseNow() c1.CloseRead(context.Background()) c2.CloseRead(context.Background()) - go func() { + errc := xsync.Go(func() error { for range time.Tick(time.Millisecond) { - if err := c1.Ping(context.Background()); err != nil { - return + err := c1.Ping(context.Background()) + if err != nil { + return err } } - }() + panic("unreachable") + }) time.Sleep(10 * time.Millisecond) assert.Success(t, c1.Close(websocket.StatusNormalClosure, "")) + <-errc }() } } diff --git a/dial_test.go b/dial_test.go index 63cb4be6..237a2874 100644 --- a/dial_test.go +++ b/dial_test.go @@ -164,11 +164,12 @@ func Test_verifyHostOverride(t *testing.T) { }, nil } - _, _, err := websocket.Dial(ctx, "ws://example.com", &websocket.DialOptions{ + c, _, err := websocket.Dial(ctx, "ws://example.com", &websocket.DialOptions{ HTTPClient: mockHTTPClient(rt), Host: tc.host, }) assert.Success(t, err) + c.CloseNow() }) } diff --git a/main_test.go b/main_test.go index 336be71c..2b93bb18 100644 --- a/main_test.go +++ b/main_test.go @@ -7,10 +7,23 @@ import ( "testing" ) +func goroutineStacks() []byte { + buf := make([]byte, 512) + for { + m := runtime.Stack(buf, true) + if m < len(buf) { + return buf[:m] + } + buf = make([]byte, len(buf)*2) + } +} + func TestMain(m *testing.M) { code := m.Run() - if runtime.NumGoroutine() != 1 { + if runtime.GOOS != "js" && runtime.NumGoroutine() != 1 || + runtime.GOOS == "js" && runtime.NumGoroutine() != 2 { fmt.Fprintf(os.Stderr, "goroutine leak detected, expected 1 but got %d goroutines\n", runtime.NumGoroutine()) + fmt.Fprintf(os.Stderr, "%s\n", goroutineStacks()) os.Exit(1) } os.Exit(code) diff --git a/read.go b/read.go index 9ab28812..5c180fba 100644 --- a/read.go +++ b/read.go @@ -219,6 +219,7 @@ func (c *Conn) readFrameHeader(ctx context.Context) (header, error) { case <-ctx.Done(): return header{}, ctx.Err() default: + c.readMu.unlock() c.close(err) return header{}, err } @@ -249,6 +250,7 @@ func (c *Conn) readFramePayload(ctx context.Context, p []byte) (int, error) { return n, ctx.Err() default: err = fmt.Errorf("failed to read frame payload: %w", err) + c.readMu.unlock() c.close(err) return n, err } @@ -319,6 +321,7 @@ func (c *Conn) handleControl(ctx context.Context, h header) (err error) { err = fmt.Errorf("received close frame: %w", ce) c.setCloseErr(err) c.writeClose(ce.Code, ce.Reason) + c.readMu.unlock() c.close(err) return err } @@ -334,6 +337,7 @@ func (c *Conn) reader(ctx context.Context) (_ MessageType, _ io.Reader, err erro if !c.msgReader.fin { err = errors.New("previous message not read to completion") + c.readMu.unlock() c.close(fmt.Errorf("failed to get reader: %w", err)) return 0, nil, err } @@ -409,6 +413,7 @@ func (mr *msgReader) Read(p []byte) (n int, err error) { } if err != nil { err = fmt.Errorf("failed to read: %w", err) + mr.c.readMu.unlock() mr.c.close(err) } return n, err diff --git a/write.go b/write.go index 3d062656..0fbfd9cd 100644 --- a/write.go +++ b/write.go @@ -109,7 +109,7 @@ func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error if !c.flate() { defer c.msgWriter.mu.unlock() - return c.writeFrame(ctx, true, false, c.msgWriter.opcode, p) + return c.writeFrame(true, ctx, true, false, c.msgWriter.opcode, p) } n, err := mw.Write(p) @@ -159,6 +159,7 @@ func (mw *msgWriter) Write(p []byte) (_ int, err error) { defer func() { if err != nil { err = fmt.Errorf("failed to write: %w", err) + mw.writeMu.unlock() mw.c.close(err) } }() @@ -179,7 +180,7 @@ func (mw *msgWriter) Write(p []byte) (_ int, err error) { } func (mw *msgWriter) write(p []byte) (int, error) { - n, err := mw.c.writeFrame(mw.ctx, false, mw.flate, mw.opcode, p) + n, err := mw.c.writeFrame(true, mw.ctx, false, mw.flate, mw.opcode, p) if err != nil { return n, fmt.Errorf("failed to write data frame: %w", err) } @@ -191,17 +192,17 @@ func (mw *msgWriter) write(p []byte) (int, error) { func (mw *msgWriter) Close() (err error) { defer errd.Wrap(&err, "failed to close writer") - if mw.closed { - return errors.New("writer already closed") - } - mw.closed = true - err = mw.writeMu.lock(mw.ctx) if err != nil { return err } defer mw.writeMu.unlock() + if mw.closed { + return errors.New("writer already closed") + } + mw.closed = true + if mw.flate { err = mw.flateWriter.Flush() if err != nil { @@ -209,7 +210,7 @@ func (mw *msgWriter) Close() (err error) { } } - _, err = mw.c.writeFrame(mw.ctx, true, mw.flate, mw.opcode, nil) + _, err = mw.c.writeFrame(true, mw.ctx, true, mw.flate, mw.opcode, nil) if err != nil { return fmt.Errorf("failed to write fin frame: %w", err) } @@ -235,7 +236,7 @@ func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error ctx, cancel := context.WithTimeout(ctx, time.Second*5) defer cancel() - _, err := c.writeFrame(ctx, true, false, opcode, p) + _, err := c.writeFrame(false, ctx, true, false, opcode, p) if err != nil { return fmt.Errorf("failed to write control frame %v: %w", opcode, err) } @@ -243,7 +244,7 @@ func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error } // frame handles all writes to the connection. -func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opcode, p []byte) (_ int, err error) { +func (c *Conn) writeFrame(msgWriter bool, ctx context.Context, fin bool, flate bool, opcode opcode, p []byte) (_ int, err error) { err = c.writeFrameMu.lock(ctx) if err != nil { return 0, err @@ -283,6 +284,10 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opco err = ctx.Err() default: } + c.writeFrameMu.unlock() + if msgWriter { + c.msgWriter.writeMu.unlock() + } c.close(err) err = fmt.Errorf("failed to write frame: %w", err) } diff --git a/ws_js.go b/ws_js.go index cae68bb6..180d0564 100644 --- a/ws_js.go +++ b/ws_js.go @@ -231,7 +231,7 @@ func (c *Conn) Close(code StatusCode, reason string) error { } // CloseNow closes the WebSocket connection without attempting a close handshake. -// Use When you do not want the overhead of the close handshake. +// Use when you do not want the overhead of the close handshake. // // note: No different from Close(StatusGoingAway, "") in WASM as there is no way to close // a WebSocket without the close handshake.