Skip to content

Commit

Permalink
wsjs: Ensure no goroutines leak after Close
Browse files Browse the repository at this point in the history
Closes #330
  • Loading branch information
nhooyr committed Oct 19, 2023
1 parent 7b1a6bb commit d91a212
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 23 deletions.
4 changes: 2 additions & 2 deletions close.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,14 +99,14 @@ func CloseStatus(err error) StatusCode {
// Close will unblock all goroutines interacting with the connection once
// complete.
func (c *Conn) Close(code StatusCode, reason string) error {
defer c.wgWait()
defer c.wg.Wait()
return c.closeHandshake(code, reason)
}

// CloseNow closes the WebSocket connection without attempting a close handshake.
// Use when you do not want the overhead of the close handshake.
func (c *Conn) CloseNow() (err error) {
defer c.wgWait()
defer c.wg.Wait()
defer errd.Wrap(&err, "failed to close WebSocket")

if c.isClosed() {
Expand Down
33 changes: 14 additions & 19 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,6 @@ const (
type Conn struct {
noCopy

wg sync.WaitGroup

subprotocol string
rwc io.ReadWriteCloser
client bool
Expand All @@ -72,6 +70,7 @@ type Conn struct {
writeHeaderBuf [8]byte
writeHeader header

wg sync.WaitGroup
closed chan struct{}
closeMu sync.Mutex
closeErr error
Expand Down Expand Up @@ -132,7 +131,11 @@ func newConn(cfg connConfig) *Conn {
c.close(errors.New("connection garbage collected"))
})

c.wgGo(c.timeoutLoop)
c.wg.Add(1)
go func() {
defer c.wg.Done()
c.timeoutLoop()
}()

return c
}
Expand Down Expand Up @@ -163,10 +166,12 @@ func (c *Conn) close(err error) {
// closeErr.
c.rwc.Close()

c.wgGo(func() {
c.wg.Add(1)
go func() {
defer c.wg.Done()
c.msgWriter.close()
c.msgReader.close()
})
}()
}

func (c *Conn) timeoutLoop() {
Expand All @@ -183,9 +188,11 @@ func (c *Conn) timeoutLoop() {

case <-readCtx.Done():
c.setCloseErr(fmt.Errorf("read timed out: %w", readCtx.Err()))
c.wgGo(func() {
c.wg.Add(1)
go func() {
defer c.wg.Done()
c.writeError(StatusPolicyViolation, errors.New("read timed out"))
})
}()
case <-writeCtx.Done():
c.close(fmt.Errorf("write timed out: %w", writeCtx.Err()))
return
Expand Down Expand Up @@ -302,15 +309,3 @@ func (m *mu) unlock() {
type noCopy struct{}

func (*noCopy) Lock() {}

func (c *Conn) wgGo(fn func()) {
c.wg.Add(1)
go func() {
defer c.wg.Done()
fn()
}()
}

func (c *Conn) wgWait() {
c.wg.Wait()
}
12 changes: 10 additions & 2 deletions ws_js.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ type Conn struct {
// read limit for a message in bytes.
msgReadLimit xsync.Int64

wg sync.WaitGroup
closingMu sync.Mutex
isReadClosed xsync.Int64
closeOnce sync.Once
Expand Down Expand Up @@ -223,6 +224,7 @@ func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) error {
// or the connection is closed.
// It thus performs the full WebSocket close handshake.
func (c *Conn) Close(code StatusCode, reason string) error {
defer c.wg.Wait()
err := c.exportedClose(code, reason)
if err != nil {
return fmt.Errorf("failed to close WebSocket: %w", err)
Expand All @@ -236,6 +238,7 @@ func (c *Conn) Close(code StatusCode, reason string) error {
// note: No different from Close(StatusGoingAway, "") in WASM as there is no way to close
// a WebSocket without the close handshake.
func (c *Conn) CloseNow() error {
defer c.wg.Wait()
return c.Close(StatusGoingAway, "")
}

Expand Down Expand Up @@ -388,10 +391,15 @@ func (c *Conn) CloseRead(ctx context.Context) context.Context {
c.isReadClosed.Store(1)

ctx, cancel := context.WithCancel(ctx)
c.wg.Add(1)
go func() {
defer c.CloseNow()
defer c.wg.Done()
defer cancel()
c.read(ctx)
c.Close(StatusPolicyViolation, "unexpected data message")
_, _, err := c.read(ctx)
if err != nil {
c.Close(StatusPolicyViolation, "unexpected data message")
}
}()
return ctx
}
Expand Down

0 comments on commit d91a212

Please sign in to comment.