diff --git a/close.go b/close.go index d78a5442..fe1ced34 100644 --- a/close.go +++ b/close.go @@ -182,6 +182,13 @@ func (c *Conn) waitCloseHandshake() error { return c.readCloseFrameErr } + for i := int64(0); i < c.msgReader.payloadLength; i++ { + _, err := c.br.ReadByte() + if err != nil { + return err + } + } + for { h, err := c.readLoop(ctx) if err != nil { diff --git a/conn_test.go b/conn_test.go index 3df6c64a..abc1c81d 100644 --- a/conn_test.go +++ b/conn_test.go @@ -308,6 +308,27 @@ func TestConn(t *testing.T) { assert.ErrorIs(t, websocket.ErrClosed, err1) assert.ErrorIs(t, websocket.ErrClosed, err2) }) + + t.Run("MidReadClose", func(t *testing.T) { + tt, c1, c2 := newConnTest(t, nil, nil) + + tt.goEchoLoop(c2) + + c1.SetReadLimit(131072) + + for i := 0; i < 5; i++ { + err := wstest.Echo(tt.ctx, c1, 131072) + assert.Success(t, err) + } + + err := wsjson.Write(tt.ctx, c1, "four") + assert.Success(t, err) + _, _, err = c1.Reader(tt.ctx) + assert.Success(t, err) + + err = c1.Close(websocket.StatusNormalClosure, "") + assert.Success(t, err) + }) } func TestWasm(t *testing.T) {