Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use simd masking for amd64&arm64 #326

Merged
merged 26 commits into from Feb 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
5df0303
mask.go: Use SIMD masking for amd64 and arm64
wdvxdr1123 Jan 24, 2022
cda2170
Refactor and compile masking code again
nhooyr Oct 19, 2023
f5397ae
mask_asm.go: Disable AVX2
nhooyr Oct 19, 2023
14172e5
Benchmark pure go masking algorithm separately from assembly
nhooyr Oct 19, 2023
685a56e
Update README.md to indicate assembly websocket masking
nhooyr Oct 19, 2023
cb7509a
mask_amd64.s: Remove AVX2 fully
nhooyr Oct 19, 2023
3f8c9e0
mask_amd64.s: Minor improvements
nhooyr Oct 19, 2023
367743d
mask_amd64.sh: Cleanup
nhooyr Oct 19, 2023
27f80cb
mask.go: Cleanup assembly and add nbio benchmark
nhooyr Oct 19, 2023
369d641
mask_arm64.s: Cleanup
nhooyr Oct 20, 2023
fb13df2
ci/bench.sh: Benchmark masking on arm64 with QEMU
nhooyr Oct 20, 2023
ecf7dec
ci/bench.sh: Install QEMU on CI
nhooyr Oct 20, 2023
d34e5d4
wsjson: Add json.Encoder vs json.Marshal benchmark
nhooyr Oct 20, 2023
e25d968
ci/bench.sh: Don't profile by default
nhooyr Oct 20, 2023
640e3c2
ci/bench.sh: Try function instead of alias
nhooyr Oct 20, 2023
0596e7a
wsjson: Extend benchmark with multiple sizes
nhooyr Oct 20, 2023
30447a3
ci/bench.sh: Just symlink the expected qemu-aarch64 binary name
nhooyr Oct 20, 2023
f4e61e5
ci/fmt.sh: Error if changes on CI
nhooyr Oct 21, 2023
f533f43
mask.go: Reorganize
nhooyr Oct 21, 2023
a1bb441
ci: Fix dev coverage output
nhooyr Feb 7, 2024
fee3739
mask_asm: Note implementation may not be perfect
nhooyr Feb 7, 2024
68fc887
mask.go: Revert my changes
nhooyr Feb 22, 2024
f62cef3
test.sh: Test assembly masking on arm64
nhooyr Feb 22, 2024
92acb74
internal/xcpu: Vendor golang.org/x/sys/cpu
nhooyr Feb 22, 2024
17e1b86
mask_asm: Disable AVX2
nhooyr Feb 22, 2024
2cd18b3
README.md: Link to assembly benchmark results
nhooyr Feb 22, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/daily.yml
Expand Up @@ -50,5 +50,5 @@ jobs:
- run: AUTOBAHN=1 ./ci/test.sh
- uses: actions/upload-artifact@v3
with:
name: coverage.html
name: coverage-dev.html
path: ./ci/out/coverage.html
6 changes: 2 additions & 4 deletions README.md
Expand Up @@ -26,6 +26,7 @@ go get nhooyr.io/websocket
- [RFC 7692](https://tools.ietf.org/html/rfc7692) permessage-deflate compression
- [CloseRead](https://pkg.go.dev/nhooyr.io/websocket#Conn.CloseRead) helper for write only connections
- Compile to [Wasm](https://pkg.go.dev/nhooyr.io/websocket#hdr-Wasm)
- WebSocket masking implemented in assembly for amd64 and arm64 [#326](https://github.com/nhooyr/websocket/issues/326)

## Roadmap

Expand All @@ -36,8 +37,6 @@ See GitHub issues for minor issues but the major future enhancements are:
- [ ] Ping pong heartbeat helper [#267](https://github.com/nhooyr/websocket/issues/267)
- [ ] Ping pong instrumentation callbacks [#246](https://github.com/nhooyr/websocket/issues/246)
- [ ] Graceful shutdown helpers [#209](https://github.com/nhooyr/websocket/issues/209)
- [ ] Assembly for WebSocket masking [#16](https://github.com/nhooyr/websocket/issues/16)
- WIP at [#326](https://github.com/nhooyr/websocket/pull/326), about 3x faster
- [ ] HTTP/2 [#4](https://github.com/nhooyr/websocket/issues/4)
- [ ] The holy grail [#402](https://github.com/nhooyr/websocket/issues/402)

Expand Down Expand Up @@ -121,9 +120,8 @@ Advantages of nhooyr.io/websocket:
- Gorilla requires registering a pong callback before sending a Ping
- Can target Wasm ([gorilla/websocket#432](https://github.com/gorilla/websocket/issues/432))
- Transparent message buffer reuse with [wsjson](https://pkg.go.dev/nhooyr.io/websocket/wsjson) subpackage
- [1.75x](https://github.com/nhooyr/websocket/releases/tag/v1.7.4) faster WebSocket masking implementation in pure Go
- [3.5x](https://github.com/nhooyr/websocket/pull/326#issuecomment-1959470758) faster WebSocket masking implementation in assembly for amd64 and arm64 and [2x](https://github.com/nhooyr/websocket/releases/tag/v1.7.4) faster implementation in pure Go
- Gorilla's implementation is slower and uses [unsafe](https://golang.org/pkg/unsafe/).
Soon we'll have assembly and be 3x faster [#326](https://github.com/nhooyr/websocket/pull/326)
- Full [permessage-deflate](https://tools.ietf.org/html/rfc7692) compression extension support
- Gorilla only supports no context takeover mode
- [CloseRead](https://pkg.go.dev/nhooyr.io/websocket#Conn.CloseRead) helper for write only connections ([gorilla/websocket#492](https://github.com/gorilla/websocket/issues/492))
Expand Down
15 changes: 13 additions & 2 deletions ci/bench.sh
Expand Up @@ -2,8 +2,19 @@
set -eu
cd -- "$(dirname "$0")/.."

go test --run=^$ --bench=. --benchmem --memprofile ci/out/prof.mem --cpuprofile ci/out/prof.cpu -o ci/out/websocket.test "$@" .
go test --run=^$ --bench=. --benchmem "$@" ./...
# For profiling add: --memprofile ci/out/prof.mem --cpuprofile ci/out/prof.cpu -o ci/out/websocket.test
(
cd ./internal/thirdparty
go test --run=^$ --bench=. --benchmem --memprofile ../../ci/out/prof-thirdparty.mem --cpuprofile ../../ci/out/prof-thirdparty.cpu -o ../../ci/out/thirdparty.test "$@" .
go test --run=^$ --bench=. --benchmem "$@" .

GOARCH=arm64 go test -c -o ../../ci/out/thirdparty-arm64.test "$@" .
if [ "$#" -eq 0 ]; then
if [ "${CI-}" ]; then
sudo apt-get update
sudo apt-get install -y qemu-user-static
ln -s /usr/bin/qemu-aarch64-static /usr/local/bin/qemu-aarch64
fi
qemu-aarch64 ../../ci/out/thirdparty-arm64.test --test.run=^$ --test.bench=Benchmark_mask --test.benchmem
fi
)
4 changes: 4 additions & 0 deletions ci/fmt.sh
Expand Up @@ -18,3 +18,7 @@ npx prettier@3.0.3 \
$(git ls-files "*.yml" "*.md" "*.js" "*.css" "*.html")

go run golang.org/x/tools/cmd/stringer@latest -type=opcode,MessageType,StatusCode -output=stringer.go

if [ "${CI-}" ]; then
git diff --exit-code
fi
13 changes: 13 additions & 0 deletions ci/test.sh
Expand Up @@ -11,6 +11,19 @@ cd -- "$(dirname "$0")/.."
go test "$@" ./...
)

(
GOARCH=arm64 go test -c -o ./ci/out/websocket-arm64.test "$@" .
if [ "$#" -eq 0 ]; then
if [ "${CI-}" ]; then
sudo apt-get update
sudo apt-get install -y qemu-user-static
ln -s /usr/bin/qemu-aarch64-static /usr/local/bin/qemu-aarch64
fi
qemu-aarch64 ./ci/out/websocket-arm64.test -test.run=TestMask
fi
)


go install github.com/agnivade/wasmbrowsertest@latest
go test --race --bench=. --timeout=1h --covermode=atomic --coverprofile=ci/out/coverage.prof --coverpkg=./... "$@" ./...
sed -i.bak '/stringer\.go/d' ci/out/coverage.prof
Expand Down
123 changes: 0 additions & 123 deletions frame.go
Expand Up @@ -8,7 +8,6 @@ import (
"fmt"
"io"
"math"
"math/bits"

"nhooyr.io/websocket/internal/errd"
)
Expand Down Expand Up @@ -172,125 +171,3 @@ func writeFrameHeader(h header, w *bufio.Writer, buf []byte) (err error) {

return nil
}

// mask applies the WebSocket masking algorithm to p
// with the given key.
// See https://tools.ietf.org/html/rfc6455#section-5.3
//
// The returned value is the correctly rotated key to
// to continue to mask/unmask the message.
//
// It is optimized for LittleEndian and expects the key
// to be in little endian.
//
// See https://github.com/golang/go/issues/31586
func mask(key uint32, b []byte) uint32 {
if len(b) >= 8 {
key64 := uint64(key)<<32 | uint64(key)

// At some point in the future we can clean these unrolled loops up.
// See https://github.com/golang/go/issues/31586#issuecomment-487436401

// Then we xor until b is less than 128 bytes.
for len(b) >= 128 {
v := binary.LittleEndian.Uint64(b)
binary.LittleEndian.PutUint64(b, v^key64)
v = binary.LittleEndian.Uint64(b[8:16])
binary.LittleEndian.PutUint64(b[8:16], v^key64)
v = binary.LittleEndian.Uint64(b[16:24])
binary.LittleEndian.PutUint64(b[16:24], v^key64)
v = binary.LittleEndian.Uint64(b[24:32])
binary.LittleEndian.PutUint64(b[24:32], v^key64)
v = binary.LittleEndian.Uint64(b[32:40])
binary.LittleEndian.PutUint64(b[32:40], v^key64)
v = binary.LittleEndian.Uint64(b[40:48])
binary.LittleEndian.PutUint64(b[40:48], v^key64)
v = binary.LittleEndian.Uint64(b[48:56])
binary.LittleEndian.PutUint64(b[48:56], v^key64)
v = binary.LittleEndian.Uint64(b[56:64])
binary.LittleEndian.PutUint64(b[56:64], v^key64)
v = binary.LittleEndian.Uint64(b[64:72])
binary.LittleEndian.PutUint64(b[64:72], v^key64)
v = binary.LittleEndian.Uint64(b[72:80])
binary.LittleEndian.PutUint64(b[72:80], v^key64)
v = binary.LittleEndian.Uint64(b[80:88])
binary.LittleEndian.PutUint64(b[80:88], v^key64)
v = binary.LittleEndian.Uint64(b[88:96])
binary.LittleEndian.PutUint64(b[88:96], v^key64)
v = binary.LittleEndian.Uint64(b[96:104])
binary.LittleEndian.PutUint64(b[96:104], v^key64)
v = binary.LittleEndian.Uint64(b[104:112])
binary.LittleEndian.PutUint64(b[104:112], v^key64)
v = binary.LittleEndian.Uint64(b[112:120])
binary.LittleEndian.PutUint64(b[112:120], v^key64)
v = binary.LittleEndian.Uint64(b[120:128])
binary.LittleEndian.PutUint64(b[120:128], v^key64)
b = b[128:]
}

// Then we xor until b is less than 64 bytes.
for len(b) >= 64 {
v := binary.LittleEndian.Uint64(b)
binary.LittleEndian.PutUint64(b, v^key64)
v = binary.LittleEndian.Uint64(b[8:16])
binary.LittleEndian.PutUint64(b[8:16], v^key64)
v = binary.LittleEndian.Uint64(b[16:24])
binary.LittleEndian.PutUint64(b[16:24], v^key64)
v = binary.LittleEndian.Uint64(b[24:32])
binary.LittleEndian.PutUint64(b[24:32], v^key64)
v = binary.LittleEndian.Uint64(b[32:40])
binary.LittleEndian.PutUint64(b[32:40], v^key64)
v = binary.LittleEndian.Uint64(b[40:48])
binary.LittleEndian.PutUint64(b[40:48], v^key64)
v = binary.LittleEndian.Uint64(b[48:56])
binary.LittleEndian.PutUint64(b[48:56], v^key64)
v = binary.LittleEndian.Uint64(b[56:64])
binary.LittleEndian.PutUint64(b[56:64], v^key64)
b = b[64:]
}

// Then we xor until b is less than 32 bytes.
for len(b) >= 32 {
v := binary.LittleEndian.Uint64(b)
binary.LittleEndian.PutUint64(b, v^key64)
v = binary.LittleEndian.Uint64(b[8:16])
binary.LittleEndian.PutUint64(b[8:16], v^key64)
v = binary.LittleEndian.Uint64(b[16:24])
binary.LittleEndian.PutUint64(b[16:24], v^key64)
v = binary.LittleEndian.Uint64(b[24:32])
binary.LittleEndian.PutUint64(b[24:32], v^key64)
b = b[32:]
}

// Then we xor until b is less than 16 bytes.
for len(b) >= 16 {
v := binary.LittleEndian.Uint64(b)
binary.LittleEndian.PutUint64(b, v^key64)
v = binary.LittleEndian.Uint64(b[8:16])
binary.LittleEndian.PutUint64(b[8:16], v^key64)
b = b[16:]
}

// Then we xor until b is less than 8 bytes.
for len(b) >= 8 {
v := binary.LittleEndian.Uint64(b)
binary.LittleEndian.PutUint64(b, v^key64)
b = b[8:]
}
}

// Then we xor until b is less than 4 bytes.
for len(b) >= 4 {
v := binary.LittleEndian.Uint32(b)
binary.LittleEndian.PutUint32(b, v^key)
b = b[4:]
}

// xor remaining bytes.
for i := range b {
b[i] ^= byte(key)
key = bits.RotateLeft32(key, -8)
}

return key
}
2 changes: 1 addition & 1 deletion frame_test.go
Expand Up @@ -97,7 +97,7 @@ func Test_mask(t *testing.T) {
key := []byte{0xa, 0xb, 0xc, 0xff}
key32 := binary.LittleEndian.Uint32(key)
p := []byte{0xa, 0xb, 0xc, 0xf2, 0xc}
gotKey32 := mask(key32, p)
gotKey32 := mask(p, key32)

expP := []byte{0, 0, 0, 0x0d, 0x6}
assert.Equal(t, "p", expP, p)
Expand Down
Empty file added go.sum
Empty file.
64 changes: 49 additions & 15 deletions internal/thirdparty/frame_test.go
Expand Up @@ -2,41 +2,54 @@ package thirdparty

import (
"encoding/binary"
"runtime"
"strconv"
"testing"
_ "unsafe"

"github.com/gobwas/ws"
_ "github.com/gorilla/websocket"
_ "github.com/lesismal/nbio/nbhttp/websocket"

_ "nhooyr.io/websocket"
)

func basicMask(maskKey [4]byte, pos int, b []byte) int {
func basicMask(b []byte, maskKey [4]byte, pos int) int {
for i := range b {
b[i] ^= maskKey[pos&3]
pos++
}
return pos & 3
}

//go:linkname maskGo nhooyr.io/websocket.maskGo
func maskGo(b []byte, key32 uint32) int

//go:linkname maskAsm nhooyr.io/websocket.maskAsm
func maskAsm(b *byte, len int, key32 uint32) uint32

//go:linkname nbioMaskBytes github.com/lesismal/nbio/nbhttp/websocket.maskXOR
func nbioMaskBytes(b, key []byte) int

//go:linkname gorillaMaskBytes github.com/gorilla/websocket.maskBytes
func gorillaMaskBytes(key [4]byte, pos int, b []byte) int

//go:linkname mask nhooyr.io/websocket.mask
func mask(key32 uint32, b []byte) int

func Benchmark_mask(b *testing.B) {
b.Run(runtime.GOARCH, benchmark_mask)
}

func benchmark_mask(b *testing.B) {
sizes := []int{
2,
3,
4,
8,
16,
32,
128,
256,
512,
1024,
2048,
4096,
8192,
16384,
}

Expand All @@ -48,22 +61,34 @@ func Benchmark_mask(b *testing.B) {
name: "basic",
fn: func(b *testing.B, key [4]byte, p []byte) {
for i := 0; i < b.N; i++ {
basicMask(key, 0, p)
basicMask(p, key, 0)
}
},
},

{
name: "nhooyr",
name: "nhooyr-go",
fn: func(b *testing.B, key [4]byte, p []byte) {
key32 := binary.LittleEndian.Uint32(key[:])
b.ResetTimer()

for i := 0; i < b.N; i++ {
maskGo(p, key32)
}
},
},
{
name: "wdvxdr1123-asm",
fn: func(b *testing.B, key [4]byte, p []byte) {
key32 := binary.LittleEndian.Uint32(key[:])
b.ResetTimer()

for i := 0; i < b.N; i++ {
mask(key32, p)
maskAsm(&p[0], len(p), key32)
}
},
},

{
name: "gorilla",
fn: func(b *testing.B, key [4]byte, p []byte) {
Expand All @@ -80,16 +105,25 @@ func Benchmark_mask(b *testing.B) {
}
},
},
{
name: "nbio",
fn: func(b *testing.B, key [4]byte, p []byte) {
keyb := key[:]
for i := 0; i < b.N; i++ {
nbioMaskBytes(p, keyb)
}
},
},
}

key := [4]byte{1, 2, 3, 4}

for _, size := range sizes {
p := make([]byte, size)
for _, fn := range fns {
b.Run(fn.name, func(b *testing.B) {
for _, size := range sizes {
p := make([]byte, size)

b.Run(strconv.Itoa(size), func(b *testing.B) {
for _, fn := range fns {
b.Run(fn.name, func(b *testing.B) {
b.Run(strconv.Itoa(size), func(b *testing.B) {
b.SetBytes(int64(size))

fn.fn(b, key, p)
Expand Down
4 changes: 3 additions & 1 deletion internal/thirdparty/go.mod
Expand Up @@ -8,6 +8,7 @@ require (
github.com/gin-gonic/gin v1.9.1
github.com/gobwas/ws v1.3.0
github.com/gorilla/websocket v1.5.0
github.com/lesismal/nbio v1.3.18
nhooyr.io/websocket v0.0.0-00010101000000-000000000000
)

Expand All @@ -25,6 +26,7 @@ require (
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/cpuid/v2 v2.2.4 // indirect
github.com/leodido/go-urn v1.2.4 // indirect
github.com/lesismal/llib v1.1.12 // indirect
github.com/mattn/go-isatty v0.0.19 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
Expand All @@ -34,7 +36,7 @@ require (
golang.org/x/arch v0.3.0 // indirect
golang.org/x/crypto v0.9.0 // indirect
golang.org/x/net v0.10.0 // indirect
golang.org/x/sys v0.8.0 // indirect
golang.org/x/sys v0.17.0 // indirect
golang.org/x/text v0.9.0 // indirect
google.golang.org/protobuf v1.30.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
Expand Down