diff --git a/conn_test.go b/conn_test.go index 9c85459e..3ca810c5 100644 --- a/conn_test.go +++ b/conn_test.go @@ -208,6 +208,37 @@ func TestConn(t *testing.T) { } }) + t.Run("netConn/readLimit", func(t *testing.T) { + tt, c1, c2 := newConnTest(t, nil, nil) + + n1 := websocket.NetConn(tt.ctx, c1, websocket.MessageBinary) + n2 := websocket.NetConn(tt.ctx, c2, websocket.MessageBinary) + + s := strings.Repeat("papa", 1 << 20) + errs := xsync.Go(func() error { + _, err := n2.Write([]byte(s)) + if err != nil { + return err + } + return n2.Close() + }) + + b, err := ioutil.ReadAll(n1) + assert.Success(t, err) + + _, err = n1.Read(nil) + assert.Equal(t, "read error", err, io.EOF) + + select { + case err := <-errs: + assert.Success(t, err) + case <-tt.ctx.Done(): + t.Fatal(tt.ctx.Err()) + } + + assert.Equal(t, "read msg", s, string(b)) + }) + t.Run("wsjson", func(t *testing.T) { tt, c1, c2 := newConnTest(t, nil, nil) diff --git a/netconn.go b/netconn.go index 1664e29b..c6f8dc13 100644 --- a/netconn.go +++ b/netconn.go @@ -38,7 +38,11 @@ import ( // // A received StatusNormalClosure or StatusGoingAway close frame will be translated to // io.EOF when reading. +// +// Furthermore, the ReadLimit is set to -1 to disable it. func NetConn(ctx context.Context, c *Conn, msgType MessageType) net.Conn { + c.SetReadLimit(-1) + nc := &netConn{ c: c, msgType: msgType, diff --git a/read.go b/read.go index 87151dcb..c4234f20 100644 --- a/read.go +++ b/read.go @@ -74,10 +74,16 @@ func (c *Conn) CloseRead(ctx context.Context) context.Context { // By default, the connection has a message read limit of 32768 bytes. // // When the limit is hit, the connection will be closed with StatusMessageTooBig. +// +// Set to -1 to disable. func (c *Conn) SetReadLimit(n int64) { - // We add read one more byte than the limit in case - // there is a fin frame that needs to be read. - c.msgReader.limitReader.limit.Store(n + 1) + if n >= 0 { + // We read one more byte than the limit in case + // there is a fin frame that needs to be read. + n++ + } + + c.msgReader.limitReader.limit.Store(n) } const defaultReadLimit = 32768 @@ -455,7 +461,11 @@ func (lr *limitReader) reset(r io.Reader) { } func (lr *limitReader) Read(p []byte) (int, error) { - if lr.n <= 0 { + if lr.n < 0 { + return lr.r.Read(p) + } + + if lr.n == 0 { err := fmt.Errorf("read limited at %v bytes", lr.limit.Load()) lr.c.writeError(StatusMessageTooBig, err) return 0, err @@ -466,6 +476,9 @@ func (lr *limitReader) Read(p []byte) (int, error) { } n, err := lr.r.Read(p) lr.n -= int64(n) + if lr.n < 0 { + lr.n = 0 + } return n, err }