Skip to content

Commit

Permalink
compress.go: Fix context takeover
Browse files Browse the repository at this point in the history
  • Loading branch information
nhooyr committed Oct 13, 2023
1 parent 4e15d75 commit a02cbef
Show file tree
Hide file tree
Showing 12 changed files with 47 additions and 37 deletions.
1 change: 1 addition & 0 deletions accept.go
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,7 @@ func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode Compressi

if strings.HasPrefix(p, "client_max_window_bits") {
// We cannot adjust the read sliding window so cannot make use of this.
// By not responding to it, we tell the client we're ignoring it.
continue
}

Expand Down
4 changes: 2 additions & 2 deletions ci/bench.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
set -eu
cd -- "$(dirname "$0")/.."

go test --bench=. "$@" ./...
go test --run=^$ --bench=. "$@" ./...
(
cd ./internal/thirdparty
go test --bench=. "$@" ./...
go test --run=^$ --bench=. "$@" ./...
)
16 changes: 6 additions & 10 deletions compress.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ const (
CompressionDisabled CompressionMode = iota

// CompressionContextTakeover uses a 32 kB sliding window and flate.Writer per connection.
// It reusing the sliding window from previous messages.
// It reuses the sliding window from previous messages.
// As most WebSocket protocols are repetitive, this can be very efficient.
// It carries an overhead of 32 kB + 1.2 MB for every connection compared to CompressionNoContextTakeover.
//
Expand Down Expand Up @@ -80,7 +80,7 @@ func (copts *compressionOptions) setHeader(h http.Header) {
// They are removed when sending to avoid the overhead as
// WebSocket framing tell's when the message has ended but then
// we need to add them back otherwise flate.Reader keeps
// trying to return more bytes.
// trying to read more bytes.
const deflateMessageTail = "\x00\x00\xff\xff"

type trimLastFourBytesWriter struct {
Expand Down Expand Up @@ -201,23 +201,19 @@ func (sw *slidingWindow) init(n int) {
}

p := slidingWindowPool(n)
buf, ok := p.Get().(*[]byte)
sw2, ok := p.Get().(*slidingWindow)
if ok {
sw.buf = (*buf)[:0]
*sw = *sw2
} else {
sw.buf = make([]byte, 0, n)
}
}

func (sw *slidingWindow) close() {
if sw.buf == nil {
return
}

sw.buf = sw.buf[:0]
swPoolMu.Lock()
swPool[cap(sw.buf)].Put(&sw.buf)
swPool[cap(sw.buf)].Put(sw)
swPoolMu.Unlock()
sw.buf = nil
}

func (sw *slidingWindow) write(p []byte) {
Expand Down
1 change: 1 addition & 0 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -292,4 +292,5 @@ func (m *mu) unlock() {
}

type noCopy struct{}

func (*noCopy) Lock() {}
4 changes: 2 additions & 2 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ func TestConn(t *testing.T) {

t.Run("HTTPClient.Timeout", func(t *testing.T) {
tt, c1, c2 := newConnTest(t, &websocket.DialOptions{
HTTPClient: &http.Client{Timeout: time.Second*5},
HTTPClient: &http.Client{Timeout: time.Second * 5},
}, nil)

tt.goEchoLoop(c2)
Expand Down Expand Up @@ -458,7 +458,7 @@ func BenchmarkConn(b *testing.B) {

typ, r, err := c1.Reader(bb.ctx)
if err != nil {
b.Fatal(err)
b.Fatal(i, err)
}
if websocket.MessageText != typ {
assert.Equal(b, "data type", websocket.MessageText, typ)
Expand Down
3 changes: 2 additions & 1 deletion dial_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"time"

"nhooyr.io/websocket/internal/test/assert"
"nhooyr.io/websocket/internal/util"
)

func TestBadDials(t *testing.T) {
Expand All @@ -27,7 +28,7 @@ func TestBadDials(t *testing.T) {
name string
url string
opts *DialOptions
rand readerFunc
rand util.ReaderFunc
nilCtx bool
}{
{
Expand Down
6 changes: 4 additions & 2 deletions export_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@

package websocket

import "nhooyr.io/websocket/internal/util"

func (c *Conn) RecordBytesWritten() *int {
var bytesWritten int
c.bw.Reset(writerFunc(func(p []byte) (int, error) {
c.bw.Reset(util.WriterFunc(func(p []byte) (int, error) {
bytesWritten += len(p)
return c.rwc.Write(p)
}))
Expand All @@ -14,7 +16,7 @@ func (c *Conn) RecordBytesWritten() *int {

func (c *Conn) RecordBytesRead() *int {
var bytesRead int
c.br.Reset(readerFunc(func(p []byte) (int, error) {
c.br.Reset(util.ReaderFunc(func(p []byte) (int, error) {
n, err := c.rwc.Read(p)
bytesRead += n
return n, err
Expand Down
7 changes: 7 additions & 0 deletions internal/util/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,10 @@ type WriterFunc func(p []byte) (int, error)
func (f WriterFunc) Write(p []byte) (int, error) {
return f(p)
}

// ReaderFunc is used to implement one off io.Readers.
type ReaderFunc func(p []byte) (int, error)

func (f ReaderFunc) Read(p []byte) (int, error) {
return f(p)
}
3 changes: 2 additions & 1 deletion internal/xsync/go.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package xsync

import (
"fmt"
"runtime/debug"
)

// Go allows running a function in another goroutine
Expand All @@ -13,7 +14,7 @@ func Go(fn func() error) <-chan error {
r := recover()
if r != nil {
select {
case errs <- fmt.Errorf("panic in go fn: %v", r):
case errs <- fmt.Errorf("panic in go fn: %v, %s", r, debug.Stack()):
default:
}
}
Expand Down
27 changes: 16 additions & 11 deletions read.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"time"

"nhooyr.io/websocket/internal/errd"
"nhooyr.io/websocket/internal/util"
"nhooyr.io/websocket/internal/xsync"
)

Expand Down Expand Up @@ -101,13 +102,20 @@ func newMsgReader(c *Conn) *msgReader {

func (mr *msgReader) resetFlate() {
if mr.flateContextTakeover() {
if mr.dict == nil {
mr.dict = &slidingWindow{}
}
mr.dict.init(32768)
}
if mr.flateBufio == nil {
mr.flateBufio = getBufioReader(mr.readFunc)
}

mr.flateReader = getFlateReader(mr.flateBufio, mr.dict.buf)
if mr.flateContextTakeover() {
mr.flateReader = getFlateReader(mr.flateBufio, mr.dict.buf)
} else {
mr.flateReader = getFlateReader(mr.flateBufio, nil)
}
mr.limitReader.r = mr.flateReader
mr.flateTail.Reset(deflateMessageTail)
}
Expand All @@ -122,7 +130,10 @@ func (mr *msgReader) putFlateReader() {
func (mr *msgReader) close() {
mr.c.readMu.forceLock()
mr.putFlateReader()
mr.dict.close()
if mr.dict != nil {
mr.dict.close()
mr.dict = nil
}
if mr.flateBufio != nil {
putBufioReader(mr.flateBufio)
}
Expand Down Expand Up @@ -348,14 +359,14 @@ type msgReader struct {
flateBufio *bufio.Reader
flateTail strings.Reader
limitReader *limitReader
dict slidingWindow
dict *slidingWindow

fin bool
payloadLength int64
maskKey uint32

// readerFunc(mr.Read) to avoid continuous allocations.
readFunc readerFunc
// util.ReaderFunc(mr.Read) to avoid continuous allocations.
readFunc util.ReaderFunc
}

func (mr *msgReader) reset(ctx context.Context, h header) {
Expand Down Expand Up @@ -484,9 +495,3 @@ func (lr *limitReader) Read(p []byte) (int, error) {
}
return n, err
}

type readerFunc func(p []byte) (int, error)

func (f readerFunc) Read(p []byte) (int, error) {
return f(p)
}
11 changes: 3 additions & 8 deletions write.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"compress/flate"

"nhooyr.io/websocket/internal/errd"
"nhooyr.io/websocket/internal/util"
)

// Writer returns a writer bounded by the context that will write
Expand Down Expand Up @@ -93,7 +94,7 @@ func newMsgWriterState(c *Conn) *msgWriterState {
func (mw *msgWriterState) ensureFlate() {
if mw.trimWriter == nil {
mw.trimWriter = &trimLastFourBytesWriter{
w: writerFunc(mw.write),
w: util.WriterFunc(mw.write),
}
}

Expand Down Expand Up @@ -380,17 +381,11 @@ func (c *Conn) writeFramePayload(p []byte) (n int, err error) {
return n, nil
}

type writerFunc func(p []byte) (int, error)

func (f writerFunc) Write(p []byte) (int, error) {
return f(p)
}

// extractBufioWriterBuf grabs the []byte backing a *bufio.Writer
// and returns it.
func extractBufioWriterBuf(bw *bufio.Writer, w io.Writer) []byte {
var writeBuf []byte
bw.Reset(writerFunc(func(p2 []byte) (int, error) {
bw.Reset(util.WriterFunc(func(p2 []byte) (int, error) {
writeBuf = p2[:cap(p2)]
return len(p2), nil
}))
Expand Down
1 change: 1 addition & 0 deletions ws_js.go
Original file line number Diff line number Diff line change
Expand Up @@ -566,4 +566,5 @@ func (m *mu) unlock() {
}

type noCopy struct{}

func (*noCopy) Lock() {}

0 comments on commit a02cbef

Please sign in to comment.