From a94999fb3a308b562b13c85f4d458564adea9147 Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Fri, 13 Oct 2023 18:28:04 -0700 Subject: [PATCH] close: Implement CloseNow Closes #384 --- close.go | 15 ++++++++++++++- conn.go | 3 +++ conn_test.go | 13 +++++++++++++ export_test.go | 2 ++ 4 files changed, 32 insertions(+), 1 deletion(-) diff --git a/close.go b/close.go index 1e13ca73..25160ee1 100644 --- a/close.go +++ b/close.go @@ -102,6 +102,19 @@ func (c *Conn) Close(code StatusCode, reason string) error { return c.closeHandshake(code, reason) } +// CloseNow closes the WebSocket connection without attempting a close handshake. +// Use When you do not want the overhead of the close handshake. +func (c *Conn) CloseNow() (err error) { + defer errd.Wrap(&err, "failed to close WebSocket") + + if c.isClosed() { + return errClosed + } + + c.close(nil) + return c.closeErr +} + func (c *Conn) closeHandshake(code StatusCode, reason string) (err error) { defer errd.Wrap(&err, "failed to close WebSocket") @@ -265,7 +278,7 @@ func (c *Conn) setCloseErr(err error) { } func (c *Conn) setCloseErrLocked(err error) { - if c.closeErr == nil { + if c.closeErr == nil && err != nil { c.closeErr = fmt.Errorf("WebSocket closed: %w", err) } } diff --git a/conn.go b/conn.go index 78eaad82..3713b1f8 100644 --- a/conn.go +++ b/conn.go @@ -147,6 +147,9 @@ func (c *Conn) close(err error) { if c.isClosed() { return } + if err == nil { + err = c.rwc.Close() + } c.setCloseErrLocked(err) close(c.closed) runtime.SetFinalizer(c, nil) diff --git a/conn_test.go b/conn_test.go index 7a6a0c39..50b844b9 100644 --- a/conn_test.go +++ b/conn_test.go @@ -295,6 +295,19 @@ func TestConn(t *testing.T) { err = c1.Close(websocket.StatusNormalClosure, "") assert.Success(t, err) }) + + t.Run("CloseNow", func(t *testing.T) { + _, c1, c2 := newConnTest(t, nil, nil) + + err1 := c1.CloseNow() + err2 := c2.CloseNow() + assert.Success(t, err1) + assert.Success(t, err2) + err1 = c1.CloseNow() + err2 = c2.CloseNow() + assert.ErrorIs(t, websocket.ErrClosed, err1) + assert.ErrorIs(t, websocket.ErrClosed, err2) + }) } func TestWasm(t *testing.T) { diff --git a/export_test.go b/export_test.go index 8731b6d8..114796d0 100644 --- a/export_test.go +++ b/export_test.go @@ -23,3 +23,5 @@ func (c *Conn) RecordBytesRead() *int { })) return &bytesRead } + +var ErrClosed = errClosed