Skip to content

Commit

Permalink
Merge pull request #254 from nhooyr/netconn-readlimit
Browse files Browse the repository at this point in the history
netconn.go: Disable read limit on WebSocket
  • Loading branch information
nhooyr committed Jan 9, 2021
2 parents 085d46c + 11af7f8 commit 642a013
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 4 deletions.
31 changes: 31 additions & 0 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,37 @@ func TestConn(t *testing.T) {
}
})

t.Run("netConn/readLimit", func(t *testing.T) {
tt, c1, c2 := newConnTest(t, nil, nil)

n1 := websocket.NetConn(tt.ctx, c1, websocket.MessageBinary)
n2 := websocket.NetConn(tt.ctx, c2, websocket.MessageBinary)

s := strings.Repeat("papa", 1 << 20)
errs := xsync.Go(func() error {
_, err := n2.Write([]byte(s))
if err != nil {
return err
}
return n2.Close()
})

b, err := ioutil.ReadAll(n1)
assert.Success(t, err)

_, err = n1.Read(nil)
assert.Equal(t, "read error", err, io.EOF)

select {
case err := <-errs:
assert.Success(t, err)
case <-tt.ctx.Done():
t.Fatal(tt.ctx.Err())
}

assert.Equal(t, "read msg", s, string(b))
})

t.Run("wsjson", func(t *testing.T) {
tt, c1, c2 := newConnTest(t, nil, nil)

Expand Down
4 changes: 4 additions & 0 deletions netconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,11 @@ import (
//
// A received StatusNormalClosure or StatusGoingAway close frame will be translated to
// io.EOF when reading.
//
// Furthermore, the ReadLimit is set to -1 to disable it.
func NetConn(ctx context.Context, c *Conn, msgType MessageType) net.Conn {
c.SetReadLimit(-1)

nc := &netConn{
c: c,
msgType: msgType,
Expand Down
21 changes: 17 additions & 4 deletions read.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,16 @@ func (c *Conn) CloseRead(ctx context.Context) context.Context {
// By default, the connection has a message read limit of 32768 bytes.
//
// When the limit is hit, the connection will be closed with StatusMessageTooBig.
//
// Set to -1 to disable.
func (c *Conn) SetReadLimit(n int64) {
// We add read one more byte than the limit in case
// there is a fin frame that needs to be read.
c.msgReader.limitReader.limit.Store(n + 1)
if n >= 0 {
// We read one more byte than the limit in case
// there is a fin frame that needs to be read.
n++
}

c.msgReader.limitReader.limit.Store(n)
}

const defaultReadLimit = 32768
Expand Down Expand Up @@ -455,7 +461,11 @@ func (lr *limitReader) reset(r io.Reader) {
}

func (lr *limitReader) Read(p []byte) (int, error) {
if lr.n <= 0 {
if lr.n < 0 {
return lr.r.Read(p)
}

if lr.n == 0 {
err := fmt.Errorf("read limited at %v bytes", lr.limit.Load())
lr.c.writeError(StatusMessageTooBig, err)
return 0, err
Expand All @@ -466,6 +476,9 @@ func (lr *limitReader) Read(p []byte) (int, error) {
}
n, err := lr.r.Read(p)
lr.n -= int64(n)
if lr.n < 0 {
lr.n = 0
}
return n, err
}

Expand Down

0 comments on commit 642a013

Please sign in to comment.