Skip to content

Commit

Permalink
Merge pull request #326 from wdvxdr1123/patch-simd-mask
Browse files Browse the repository at this point in the history
use simd masking for amd64&arm64
  • Loading branch information
nhooyr committed Feb 22, 2024
2 parents 22c9092 + 2cd18b3 commit 8a54c1b
Show file tree
Hide file tree
Showing 21 changed files with 621 additions and 152 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/daily.yml
Expand Up @@ -50,5 +50,5 @@ jobs:
- run: AUTOBAHN=1 ./ci/test.sh
- uses: actions/upload-artifact@v3
with:
name: coverage.html
name: coverage-dev.html
path: ./ci/out/coverage.html
6 changes: 2 additions & 4 deletions README.md
Expand Up @@ -26,6 +26,7 @@ go get nhooyr.io/websocket
- [RFC 7692](https://tools.ietf.org/html/rfc7692) permessage-deflate compression
- [CloseRead](https://pkg.go.dev/nhooyr.io/websocket#Conn.CloseRead) helper for write only connections
- Compile to [Wasm](https://pkg.go.dev/nhooyr.io/websocket#hdr-Wasm)
- WebSocket masking implemented in assembly for amd64 and arm64 [#326](https://github.com/nhooyr/websocket/issues/326)

## Roadmap

Expand All @@ -36,8 +37,6 @@ See GitHub issues for minor issues but the major future enhancements are:
- [ ] Ping pong heartbeat helper [#267](https://github.com/nhooyr/websocket/issues/267)
- [ ] Ping pong instrumentation callbacks [#246](https://github.com/nhooyr/websocket/issues/246)
- [ ] Graceful shutdown helpers [#209](https://github.com/nhooyr/websocket/issues/209)
- [ ] Assembly for WebSocket masking [#16](https://github.com/nhooyr/websocket/issues/16)
- WIP at [#326](https://github.com/nhooyr/websocket/pull/326), about 3x faster
- [ ] HTTP/2 [#4](https://github.com/nhooyr/websocket/issues/4)
- [ ] The holy grail [#402](https://github.com/nhooyr/websocket/issues/402)

Expand Down Expand Up @@ -121,9 +120,8 @@ Advantages of nhooyr.io/websocket:
- Gorilla requires registering a pong callback before sending a Ping
- Can target Wasm ([gorilla/websocket#432](https://github.com/gorilla/websocket/issues/432))
- Transparent message buffer reuse with [wsjson](https://pkg.go.dev/nhooyr.io/websocket/wsjson) subpackage
- [1.75x](https://github.com/nhooyr/websocket/releases/tag/v1.7.4) faster WebSocket masking implementation in pure Go
- [3.5x](https://github.com/nhooyr/websocket/pull/326#issuecomment-1959470758) faster WebSocket masking implementation in assembly for amd64 and arm64 and [2x](https://github.com/nhooyr/websocket/releases/tag/v1.7.4) faster implementation in pure Go
- Gorilla's implementation is slower and uses [unsafe](https://golang.org/pkg/unsafe/).
Soon we'll have assembly and be 3x faster [#326](https://github.com/nhooyr/websocket/pull/326)
- Full [permessage-deflate](https://tools.ietf.org/html/rfc7692) compression extension support
- Gorilla only supports no context takeover mode
- [CloseRead](https://pkg.go.dev/nhooyr.io/websocket#Conn.CloseRead) helper for write only connections ([gorilla/websocket#492](https://github.com/gorilla/websocket/issues/492))
Expand Down
15 changes: 13 additions & 2 deletions ci/bench.sh
Expand Up @@ -2,8 +2,19 @@
set -eu
cd -- "$(dirname "$0")/.."

go test --run=^$ --bench=. --benchmem --memprofile ci/out/prof.mem --cpuprofile ci/out/prof.cpu -o ci/out/websocket.test "$@" .
go test --run=^$ --bench=. --benchmem "$@" ./...
# For profiling add: --memprofile ci/out/prof.mem --cpuprofile ci/out/prof.cpu -o ci/out/websocket.test
(
cd ./internal/thirdparty
go test --run=^$ --bench=. --benchmem --memprofile ../../ci/out/prof-thirdparty.mem --cpuprofile ../../ci/out/prof-thirdparty.cpu -o ../../ci/out/thirdparty.test "$@" .
go test --run=^$ --bench=. --benchmem "$@" .

GOARCH=arm64 go test -c -o ../../ci/out/thirdparty-arm64.test "$@" .
if [ "$#" -eq 0 ]; then
if [ "${CI-}" ]; then
sudo apt-get update
sudo apt-get install -y qemu-user-static
ln -s /usr/bin/qemu-aarch64-static /usr/local/bin/qemu-aarch64
fi
qemu-aarch64 ../../ci/out/thirdparty-arm64.test --test.run=^$ --test.bench=Benchmark_mask --test.benchmem
fi
)
4 changes: 4 additions & 0 deletions ci/fmt.sh
Expand Up @@ -18,3 +18,7 @@ npx prettier@3.0.3 \
$(git ls-files "*.yml" "*.md" "*.js" "*.css" "*.html")

go run golang.org/x/tools/cmd/stringer@latest -type=opcode,MessageType,StatusCode -output=stringer.go

if [ "${CI-}" ]; then
git diff --exit-code
fi
13 changes: 13 additions & 0 deletions ci/test.sh
Expand Up @@ -11,6 +11,19 @@ cd -- "$(dirname "$0")/.."
go test "$@" ./...
)

(
GOARCH=arm64 go test -c -o ./ci/out/websocket-arm64.test "$@" .
if [ "$#" -eq 0 ]; then
if [ "${CI-}" ]; then
sudo apt-get update
sudo apt-get install -y qemu-user-static
ln -s /usr/bin/qemu-aarch64-static /usr/local/bin/qemu-aarch64
fi
qemu-aarch64 ./ci/out/websocket-arm64.test -test.run=TestMask
fi
)


go install github.com/agnivade/wasmbrowsertest@latest
go test --race --bench=. --timeout=1h --covermode=atomic --coverprofile=ci/out/coverage.prof --coverpkg=./... "$@" ./...
sed -i.bak '/stringer\.go/d' ci/out/coverage.prof
Expand Down
123 changes: 0 additions & 123 deletions frame.go
Expand Up @@ -8,7 +8,6 @@ import (
"fmt"
"io"
"math"
"math/bits"

"nhooyr.io/websocket/internal/errd"
)
Expand Down Expand Up @@ -172,125 +171,3 @@ func writeFrameHeader(h header, w *bufio.Writer, buf []byte) (err error) {

return nil
}

// mask applies the WebSocket masking algorithm to p
// with the given key.
// See https://tools.ietf.org/html/rfc6455#section-5.3
//
// The returned value is the correctly rotated key to
// to continue to mask/unmask the message.
//
// It is optimized for LittleEndian and expects the key
// to be in little endian.
//
// See https://github.com/golang/go/issues/31586
func mask(key uint32, b []byte) uint32 {
if len(b) >= 8 {
key64 := uint64(key)<<32 | uint64(key)

// At some point in the future we can clean these unrolled loops up.
// See https://github.com/golang/go/issues/31586#issuecomment-487436401

// Then we xor until b is less than 128 bytes.
for len(b) >= 128 {
v := binary.LittleEndian.Uint64(b)
binary.LittleEndian.PutUint64(b, v^key64)
v = binary.LittleEndian.Uint64(b[8:16])
binary.LittleEndian.PutUint64(b[8:16], v^key64)
v = binary.LittleEndian.Uint64(b[16:24])
binary.LittleEndian.PutUint64(b[16:24], v^key64)
v = binary.LittleEndian.Uint64(b[24:32])
binary.LittleEndian.PutUint64(b[24:32], v^key64)
v = binary.LittleEndian.Uint64(b[32:40])
binary.LittleEndian.PutUint64(b[32:40], v^key64)
v = binary.LittleEndian.Uint64(b[40:48])
binary.LittleEndian.PutUint64(b[40:48], v^key64)
v = binary.LittleEndian.Uint64(b[48:56])
binary.LittleEndian.PutUint64(b[48:56], v^key64)
v = binary.LittleEndian.Uint64(b[56:64])
binary.LittleEndian.PutUint64(b[56:64], v^key64)
v = binary.LittleEndian.Uint64(b[64:72])
binary.LittleEndian.PutUint64(b[64:72], v^key64)
v = binary.LittleEndian.Uint64(b[72:80])
binary.LittleEndian.PutUint64(b[72:80], v^key64)
v = binary.LittleEndian.Uint64(b[80:88])
binary.LittleEndian.PutUint64(b[80:88], v^key64)
v = binary.LittleEndian.Uint64(b[88:96])
binary.LittleEndian.PutUint64(b[88:96], v^key64)
v = binary.LittleEndian.Uint64(b[96:104])
binary.LittleEndian.PutUint64(b[96:104], v^key64)
v = binary.LittleEndian.Uint64(b[104:112])
binary.LittleEndian.PutUint64(b[104:112], v^key64)
v = binary.LittleEndian.Uint64(b[112:120])
binary.LittleEndian.PutUint64(b[112:120], v^key64)
v = binary.LittleEndian.Uint64(b[120:128])
binary.LittleEndian.PutUint64(b[120:128], v^key64)
b = b[128:]
}

// Then we xor until b is less than 64 bytes.
for len(b) >= 64 {
v := binary.LittleEndian.Uint64(b)
binary.LittleEndian.PutUint64(b, v^key64)
v = binary.LittleEndian.Uint64(b[8:16])
binary.LittleEndian.PutUint64(b[8:16], v^key64)
v = binary.LittleEndian.Uint64(b[16:24])
binary.LittleEndian.PutUint64(b[16:24], v^key64)
v = binary.LittleEndian.Uint64(b[24:32])
binary.LittleEndian.PutUint64(b[24:32], v^key64)
v = binary.LittleEndian.Uint64(b[32:40])
binary.LittleEndian.PutUint64(b[32:40], v^key64)
v = binary.LittleEndian.Uint64(b[40:48])
binary.LittleEndian.PutUint64(b[40:48], v^key64)
v = binary.LittleEndian.Uint64(b[48:56])
binary.LittleEndian.PutUint64(b[48:56], v^key64)
v = binary.LittleEndian.Uint64(b[56:64])
binary.LittleEndian.PutUint64(b[56:64], v^key64)
b = b[64:]
}

// Then we xor until b is less than 32 bytes.
for len(b) >= 32 {
v := binary.LittleEndian.Uint64(b)
binary.LittleEndian.PutUint64(b, v^key64)
v = binary.LittleEndian.Uint64(b[8:16])
binary.LittleEndian.PutUint64(b[8:16], v^key64)
v = binary.LittleEndian.Uint64(b[16:24])
binary.LittleEndian.PutUint64(b[16:24], v^key64)
v = binary.LittleEndian.Uint64(b[24:32])
binary.LittleEndian.PutUint64(b[24:32], v^key64)
b = b[32:]
}

// Then we xor until b is less than 16 bytes.
for len(b) >= 16 {
v := binary.LittleEndian.Uint64(b)
binary.LittleEndian.PutUint64(b, v^key64)
v = binary.LittleEndian.Uint64(b[8:16])
binary.LittleEndian.PutUint64(b[8:16], v^key64)
b = b[16:]
}

// Then we xor until b is less than 8 bytes.
for len(b) >= 8 {
v := binary.LittleEndian.Uint64(b)
binary.LittleEndian.PutUint64(b, v^key64)
b = b[8:]
}
}

// Then we xor until b is less than 4 bytes.
for len(b) >= 4 {
v := binary.LittleEndian.Uint32(b)
binary.LittleEndian.PutUint32(b, v^key)
b = b[4:]
}

// xor remaining bytes.
for i := range b {
b[i] ^= byte(key)
key = bits.RotateLeft32(key, -8)
}

return key
}
2 changes: 1 addition & 1 deletion frame_test.go
Expand Up @@ -97,7 +97,7 @@ func Test_mask(t *testing.T) {
key := []byte{0xa, 0xb, 0xc, 0xff}
key32 := binary.LittleEndian.Uint32(key)
p := []byte{0xa, 0xb, 0xc, 0xf2, 0xc}
gotKey32 := mask(key32, p)
gotKey32 := mask(p, key32)

expP := []byte{0, 0, 0, 0x0d, 0x6}
assert.Equal(t, "p", expP, p)
Expand Down
Empty file added go.sum
Empty file.
64 changes: 49 additions & 15 deletions internal/thirdparty/frame_test.go
Expand Up @@ -2,41 +2,54 @@ package thirdparty

import (
"encoding/binary"
"runtime"
"strconv"
"testing"
_ "unsafe"

"github.com/gobwas/ws"
_ "github.com/gorilla/websocket"
_ "github.com/lesismal/nbio/nbhttp/websocket"

_ "nhooyr.io/websocket"
)

func basicMask(maskKey [4]byte, pos int, b []byte) int {
func basicMask(b []byte, maskKey [4]byte, pos int) int {
for i := range b {
b[i] ^= maskKey[pos&3]
pos++
}
return pos & 3
}

//go:linkname maskGo nhooyr.io/websocket.maskGo
func maskGo(b []byte, key32 uint32) int

//go:linkname maskAsm nhooyr.io/websocket.maskAsm
func maskAsm(b *byte, len int, key32 uint32) uint32

//go:linkname nbioMaskBytes github.com/lesismal/nbio/nbhttp/websocket.maskXOR
func nbioMaskBytes(b, key []byte) int

//go:linkname gorillaMaskBytes github.com/gorilla/websocket.maskBytes
func gorillaMaskBytes(key [4]byte, pos int, b []byte) int

//go:linkname mask nhooyr.io/websocket.mask
func mask(key32 uint32, b []byte) int

func Benchmark_mask(b *testing.B) {
b.Run(runtime.GOARCH, benchmark_mask)
}

func benchmark_mask(b *testing.B) {
sizes := []int{
2,
3,
4,
8,
16,
32,
128,
256,
512,
1024,
2048,
4096,
8192,
16384,
}

Expand All @@ -48,22 +61,34 @@ func Benchmark_mask(b *testing.B) {
name: "basic",
fn: func(b *testing.B, key [4]byte, p []byte) {
for i := 0; i < b.N; i++ {
basicMask(key, 0, p)
basicMask(p, key, 0)
}
},
},

{
name: "nhooyr",
name: "nhooyr-go",
fn: func(b *testing.B, key [4]byte, p []byte) {
key32 := binary.LittleEndian.Uint32(key[:])
b.ResetTimer()

for i := 0; i < b.N; i++ {
maskGo(p, key32)
}
},
},
{
name: "wdvxdr1123-asm",
fn: func(b *testing.B, key [4]byte, p []byte) {
key32 := binary.LittleEndian.Uint32(key[:])
b.ResetTimer()

for i := 0; i < b.N; i++ {
mask(key32, p)
maskAsm(&p[0], len(p), key32)
}
},
},

{
name: "gorilla",
fn: func(b *testing.B, key [4]byte, p []byte) {
Expand All @@ -80,16 +105,25 @@ func Benchmark_mask(b *testing.B) {
}
},
},
{
name: "nbio",
fn: func(b *testing.B, key [4]byte, p []byte) {
keyb := key[:]
for i := 0; i < b.N; i++ {
nbioMaskBytes(p, keyb)
}
},
},
}

key := [4]byte{1, 2, 3, 4}

for _, size := range sizes {
p := make([]byte, size)
for _, fn := range fns {
b.Run(fn.name, func(b *testing.B) {
for _, size := range sizes {
p := make([]byte, size)

b.Run(strconv.Itoa(size), func(b *testing.B) {
for _, fn := range fns {
b.Run(fn.name, func(b *testing.B) {
b.Run(strconv.Itoa(size), func(b *testing.B) {
b.SetBytes(int64(size))

fn.fn(b, key, p)
Expand Down
4 changes: 3 additions & 1 deletion internal/thirdparty/go.mod
Expand Up @@ -8,6 +8,7 @@ require (
github.com/gin-gonic/gin v1.9.1
github.com/gobwas/ws v1.3.0
github.com/gorilla/websocket v1.5.0
github.com/lesismal/nbio v1.3.18
nhooyr.io/websocket v0.0.0-00010101000000-000000000000
)

Expand All @@ -25,6 +26,7 @@ require (
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/cpuid/v2 v2.2.4 // indirect
github.com/leodido/go-urn v1.2.4 // indirect
github.com/lesismal/llib v1.1.12 // indirect
github.com/mattn/go-isatty v0.0.19 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
Expand All @@ -34,7 +36,7 @@ require (
golang.org/x/arch v0.3.0 // indirect
golang.org/x/crypto v0.9.0 // indirect
golang.org/x/net v0.10.0 // indirect
golang.org/x/sys v0.8.0 // indirect
golang.org/x/sys v0.17.0 // indirect
golang.org/x/text v0.9.0 // indirect
google.golang.org/protobuf v1.30.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
Expand Down

0 comments on commit 8a54c1b

Please sign in to comment.