diff --git a/xof/k12/k12.go b/xof/k12/k12.go
index 3c503a18..f696682d 100644
--- a/xof/k12/k12.go
+++ b/xof/k12/k12.go
@@ -8,6 +8,7 @@ package k12
 
 import (
     "encoding/binary"
+    "sync"
 
     "github.com/cloudflare/circl/internal/sha3"
    "github.com/cloudflare/circl/simd/keccakf1600"
@@ -22,7 +23,10 @@ const chunkSize = 8192 // aka B
 
 // If we have a fast TurboSHAKE128 available, we buffer chunks until we have
 // enough to do the parallel TurboSHAKE128. If not, we absorb directly into
 // a separate TurboSHAKE128 state.
+// If requested by the user, worker goroutines are spun up to compute
+// multiple chunks in parallel.
+
 // State stores the intermediate state for computing a Kangaroo12 hash.
 type State struct {
     initialTodo int // Bytes left to absorb for the first chunk.
@@ -45,11 +49,68 @@ type State struct {
     // a fast parallel TurboSHAKE128, viz when lanes == 1.
     leaf *sha3.State
 
-    lanes uint8 // number of TurboSHAKE128s to compute in parallel
+    workers      int   // number of parallel workers; 1 if in single-threaded mode.
+    lanes        uint8 // number of TurboSHAKE128s to compute in parallel
+    maxWriteSize int   // cached return of MaxWriteSize()
+
+    // nil if absorbing the first chunk or if operating in single-threaded
+    // mode, and otherwise contains all the buffers and locks to deal with
+    // the multithreaded computation.
+    w *workersState
+}
+
+type result struct {
+    data   []byte // data to be hashed
+    hashes []byte // resulting chunk hashes
+    done   bool
+}
+
+type workersState struct {
+    // Ring buffer into which the results (the chunk hashes "CV_i") are
+    // written by the workers.
+    ring []result
+
+    // Reader offset in ring: the first result we're waiting for that
+    // hasn't come back yet, or wOff if the ring buffer is empty.
+    rOff int
+
+    // Writer offset in ring: the first free result slot, or equal to
+    // rOff-1 modulo len(ring) if the ring buffer is "full". For simplicity,
+    // we always leave one free slot to distinguish between an empty and
+    // a full buffer.
+    wOff int
+
+    // Task offset in ring: the last slot that has been picked up by a worker.
+    // Thus tOff == wOff when all tasks have been picked up.
+    tOff int
+
+    // Used to wait for the workers to finish after noMore has been set
+    // and the final task has been dispatched.
+    wg sync.WaitGroup
+
+    // Used to wait on when tOff == wOff.
+    taskCond *sync.Cond
+
+    // Number of workers waiting on taskCond.
+    taskWaiting int
+
+    // Used to wait on when the ring is full.
+    writeSlotCond *sync.Cond
+
+    // Number of writers waiting on writeSlotCond for a free slot.
+    writeSlotWaiting int
+
+    // True if a worker is writing to the stalk.
+    hashing bool
+
+    mux sync.Mutex
+
+    // True if no more data is going to be written.
+    noMore bool
 }
 
 // NewDraft10 creates a new instance of Kangaroo12 draft version -10.
-func NewDraft10(c []byte) State {
+func NewDraft10(opts ...Option) State {
     var lanes byte = 1
     if keccakf1600.IsEnabledX4() {
@@ -58,19 +119,149 @@ func NewDraft10(c []byte) State {
         lanes = 2
     }
 
-    return newDraft10(c, lanes)
+    o := options{
+        lanes:   lanes,
+        workers: 1,
+    }
+
+    o.apply(opts)
+
+    return newDraft10(o)
+}
+
+type options struct {
+    workers int
+    lanes   byte
+    context []byte
+}
+
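For reviewers, a minimal sketch of how the new functional-options API is meant to be called. This is illustrative only, not part of the diff; WithWorkers(8) and the context string are arbitrary choices, not defaults.

package main

import (
    "fmt"

    "github.com/cloudflare/circl/xof/k12"
)

func main() {
    // Hash with 8 parallel workers and a custom context string.
    h := k12.NewDraft10(k12.WithWorkers(8), k12.WithContext([]byte("example")))
    _, _ = h.Write([]byte("hello world"))

    out := make([]byte, 32) // K12 is an XOF; any output length works
    _, _ = h.Read(out)
    fmt.Printf("%x\n", out)
}
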
+// Option configures K12; for instance WithContext([]byte("context string")).
+type Option func(*options)
+
+func (o *options) apply(opts []Option) {
+    for _, opt := range opts {
+        opt(o)
+    }
 }
 
-func newDraft10(c []byte, lanes byte) State {
-    return State{
-        initialTodo: chunkSize,
-        stalk:       sha3.NewTurboShake128(0x07),
-        context:     c,
-        lanes:       lanes,
+// WithWorkers sets the number of parallel workers to use in the computation.
+func WithWorkers(workers int) Option {
+    return func(opts *options) {
+        opts.workers = workers
     }
 }
 
+// WithContext sets the context string used by K12.
+func WithContext(context []byte) Option {
+    return func(opts *options) {
+        opts.context = context
+    }
+}
+
+func newDraft10(opts options) State {
+    if opts.workers == 0 {
+        opts.workers = 1
+    }
+
+    mws := int(opts.lanes) * chunkSize * opts.workers
+
+    ret := State{
+        initialTodo:  chunkSize,
+        stalk:        sha3.NewTurboShake128(0x07),
+        context:      opts.context,
+        lanes:        opts.lanes,
+        workers:      opts.workers,
+        maxWriteSize: mws,
+    }
+
+    return ret
+}
+
+func (s *State) worker() {
+    s.w.mux.Lock()
+    for {
+        for s.w.tOff == s.w.wOff && !s.w.noMore {
+            s.w.taskWaiting++
+            s.w.taskCond.Wait()
+            s.w.taskWaiting--
+        }
+
+        if s.w.tOff == s.w.wOff && s.w.noMore {
+            break
+        }
+
+        // Claim a batch of pending tasks.
+        offset := s.w.tOff
+        s.w.tOff = (s.w.tOff + 1) % len(s.w.ring)
+        count := 1
+        for s.w.tOff != s.w.wOff && count <= 16 {
+            count++
+            s.w.tOff = (s.w.tOff + 1) % len(s.w.ring)
+        }
+
+        s.w.mux.Unlock()
+
+        for i := 0; i < count; i++ {
+            switch s.lanes {
+            case 4:
+                computeX4(
+                    s.w.ring[(offset+i)%len(s.w.ring)].data,
+                    s.w.ring[(offset+i)%len(s.w.ring)].hashes,
+                )
+            default:
+                computeX2(
+                    s.w.ring[(offset+i)%len(s.w.ring)].data,
+                    s.w.ring[(offset+i)%len(s.w.ring)].hashes,
+                )
+            }
+        }
+
+        s.w.mux.Lock()
+        for i := 0; i < count; i++ {
+            s.w.ring[(offset+i)%len(s.w.ring)].done = true
+        }
+
+        if !s.w.hashing {
+            processed := 0
+            s.w.hashing = true
+
+            for s.w.rOff != s.w.wOff && s.w.ring[s.w.rOff].done {
+                hashOffset := s.w.rOff
+                s.w.mux.Unlock()
+                _, _ = s.stalk.Write(s.w.ring[hashOffset].hashes)
+                s.w.mux.Lock()
+
+                if hashOffset != s.w.rOff {
+                    panic("shouldn't happen")
+                }
+
+                s.w.ring[s.w.rOff].done = false
+                s.chunk += uint(s.lanes)
+                s.w.rOff = (1 + s.w.rOff) % len(s.w.ring)
+                processed++
+            }
+
+            s.w.hashing = false
+
+            if s.w.writeSlotWaiting > 0 && processed > 0 {
+                s.w.writeSlotCond.Broadcast()
+            }
+        }
+    }
+    s.w.mux.Unlock()
+
+    s.w.wg.Done()
+}
+
 func (s *State) Reset() {
+    if s.w != nil {
+        s.w.mux.Lock()
+        s.w.noMore = true
+        s.w.taskCond.Broadcast()
+        s.w.mux.Unlock()
+        s.w.wg.Wait()
+        s.w = nil
+    }
+
     s.initialTodo = chunkSize
     s.stalk.Reset()
     s.stalk.SwitchDS(0x07)
@@ -79,7 +270,15 @@ func (s *State) Reset() {
     s.chunk = 0
 }
 
+// Clone creates a copy of the current state.
+//
+// Not supported in multithreaded mode (viz. when using the WithWorkers option).
 func (s *State) Clone() State {
+    if s.w != nil {
+        // TODO Do we want to implement this?
+        panic("Clone not supported with parallel workers")
+    }
+
     stalk := s.stalk.Clone().(*sha3.State)
     ret := State{
         initialTodo: s.initialTodo,
@@ -102,13 +301,35 @@ func (s *State) Clone() State {
     return ret
 }
 
-func Draft10Sum(hash []byte, msg []byte, c []byte) {
-    // TODO Tweak number of lanes depending on the length of the message
-    s := NewDraft10(c)
+func Draft10Sum(hash []byte, msg []byte, opts ...Option) {
+    // TODO Tweak number of lanes/workers depending on the length of the message
+    s := NewDraft10(opts...)
     _, _ = s.Write(msg)
     _, _ = s.Read(hash)
 }
 
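The rOff/wOff/tOff bookkeeping above follows the classic "leave one slot free" ring-buffer convention, so rOff == wOff can only mean "empty". A stand-alone sketch of just that convention (hypothetical type, not from this patch) may help in checking the invariants:

package main

import "fmt"

type ring struct {
    slots      []int
    rOff, wOff int // reader and writer offsets
}

func (r *ring) empty() bool { return r.rOff == r.wOff }

// full leaves one slot unused so that rOff == wOff can only mean "empty".
func (r *ring) full() bool { return (r.wOff+1)%len(r.slots) == r.rOff }

func (r *ring) push(v int) bool {
    if r.full() {
        return false
    }
    r.slots[r.wOff] = v
    r.wOff = (r.wOff + 1) % len(r.slots)
    return true
}

func (r *ring) pop() (int, bool) {
    if r.empty() {
        return 0, false
    }
    v := r.slots[r.rOff]
    r.rOff = (r.rOff + 1) % len(r.slots)
    return v, true
}

func main() {
    r := ring{slots: make([]int, 4)} // capacity 3: one slot stays free
    fmt.Println(r.push(1), r.push(2), r.push(3), r.push(4)) // true true true false
    v, ok := r.pop()
    fmt.Println(v, ok) // 1 true
}
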
+// NextWriteSize suggests a favorable size for the buffer passed to the next
+// call to Write().
+func (s *State) NextWriteSize() int {
+    if s.initialTodo != 0 {
+        return s.initialTodo
+    }
+
+    if s.offset != 0 {
+        return len(s.buf) - s.offset
+    }
+
+    return s.maxWriteSize
+}
+
+// MaxWriteSize is the largest value that will be returned from NextWriteSize().
+//
+// This can be used to determine the size of a buffer which will be
+// fed into Write().
+func (s *State) MaxWriteSize() int {
+    return s.maxWriteSize
+}
+
 func (s *State) Write(p []byte) (int, error) {
     written := len(p)
 
@@ -142,6 +363,28 @@ func (s *State) Write(p []byte) (int, error) {
         }
         _, _ = s.stalk.Write([]byte{0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00})
         s.stalk.SwitchDS(0x06)
+
+        // Kick off the workers, if in multi-threaded mode.
+        if s.workers != 1 && s.lanes != 1 {
+            s.w = &workersState{
+                ring: make([]result, 64*s.workers+1),
+            }
+            s.w.writeSlotCond = sync.NewCond(&s.w.mux)
+            s.w.taskCond = sync.NewCond(&s.w.mux)
+
+            // TODO Check if it's better to use one single buffer. That reduces
+            // the number of allocations, but increases the false sharing if
+            // not done carefully.
+            for i := 0; i < len(s.w.ring); i++ {
+                s.w.ring[i].hashes = make([]byte, 32*int(s.lanes))
+                s.w.ring[i].data = make([]byte, int(s.lanes)*chunkSize)
+            }
+
+            s.w.wg.Add(s.workers)
+            for i := 0; i < s.workers; i++ {
+                go s.worker()
+            }
+        }
     }
 
     // If we're just using one lane, we don't need to cache in a buffer
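NextWriteSize() exists so callers can keep the multi-threaded path fed with whole lanes*chunkSize tasks. A sketch of the intended feeding loop, mirroring the benchmark further down (reading from stdin and the 32-byte output length are illustrative choices, not part of the diff):

package main

import (
    "fmt"
    "io"
    "os"
    "runtime"

    "github.com/cloudflare/circl/xof/k12"
)

func main() {
    h := k12.NewDraft10(k12.WithWorkers(runtime.NumCPU()))
    buf := make([]byte, h.MaxWriteSize())
    for {
        // Ask the state how much it would like next, so that the
        // multi-threaded path receives whole tasks.
        n, err := io.ReadFull(os.Stdin, buf[:h.NextWriteSize()])
        if n > 0 {
            _, _ = h.Write(buf[:n])
        }
        if err != nil { // io.EOF or io.ErrUnexpectedEOF ends the stream
            break
        }
    }
    out := make([]byte, 32)
    _, _ = h.Read(out)
    fmt.Printf("%x\n", out)
}
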
@@ -212,129 +455,169 @@ func (s *State) Write(p []byte) (int, error) {
 
 // Absorb a multiple of lanes * chunkSize.
 // Returns the remainder.
 func (s *State) writeX(p []byte) []byte {
+    if s.w != nil {
+        taskSize := int(s.lanes) * chunkSize
+        s.w.mux.Lock()
+        for len(p) >= taskSize {
+            maxCount := len(p) / taskSize
+
+            // Find the number of free slots.
+            count := 0
+            offset := s.w.wOff
+            for (offset+count+1)%len(s.w.ring) != s.w.rOff && count < maxCount {
+                if s.w.ring[(offset+count)%len(s.w.ring)].done {
+                    panic("entry shouldn't be done")
+                }
+                count++
+            }
+
+            if count == 0 {
+                // Ring is full; need to wait.
+                s.w.writeSlotWaiting++
+                s.w.writeSlotCond.Wait()
+                s.w.writeSlotWaiting--
+                continue
+            }
+            s.w.mux.Unlock()
+
+            for i := 0; i < count; i++ {
+                copy(s.w.ring[(offset+i)%len(s.w.ring)].data, p[:taskSize])
+                p = p[taskSize:]
+            }
+
+            s.w.mux.Lock()
+            if s.w.wOff != offset {
+                panic("multiple writers are not allowed")
+            }
+            s.w.wOff = (s.w.wOff + count) % len(s.w.ring)
+            if s.w.taskWaiting > 0 {
+                for i := 0; i < count; i++ {
+                    s.w.taskCond.Signal()
+                }
+            }
+        }
+        s.w.mux.Unlock()
+        return p
+    }
+
     switch s.lanes {
     case 4:
-        return s.writeX4(p)
+        var buf [4 * 32]byte
+        for len(p) >= 4*chunkSize {
+            computeX4(p, buf[:])
+            _, _ = s.stalk.Write(buf[:])
+            p = p[chunkSize*4:]
+            s.chunk += 4
+        }
     default:
-        return s.writeX2(p)
+        var buf [2 * 32]byte
+        for len(p) >= 2*chunkSize {
+            computeX2(p, buf[:])
+            _, _ = s.stalk.Write(buf[:])
+            p = p[chunkSize*2:]
+            s.chunk += 2
+        }
     }
+    return p
 }
 
-func (s *State) writeX4(p []byte) []byte {
-    for len(p) >= 4*chunkSize {
-        var x4 keccakf1600.StateX4
-        a := x4.Initialize(true)
-
-        for offset := 0; offset < 48*168; offset += 168 {
-            for i := 0; i < 21; i++ {
-                a[i*4] ^= binary.LittleEndian.Uint64(
-                    p[8*i+offset:],
-                )
-                a[i*4+1] ^= binary.LittleEndian.Uint64(
-                    p[chunkSize+8*i+offset:],
-                )
-                a[i*4+2] ^= binary.LittleEndian.Uint64(
-                    p[chunkSize*2+8*i+offset:],
-                )
-                a[i*4+3] ^= binary.LittleEndian.Uint64(
-                    p[chunkSize*3+8*i+offset:],
-                )
-            }
-
-            x4.Permute()
-        }
+func computeX4(p, out []byte) {
+    var x4 keccakf1600.StateX4
+    a := x4.Initialize(true)
 
-        for i := 0; i < 16; i++ {
+    for offset := 0; offset < 48*168; offset += 168 {
+        for i := 0; i < 21; i++ {
             a[i*4] ^= binary.LittleEndian.Uint64(
-                p[8*i+48*168:],
+                p[8*i+offset:],
             )
             a[i*4+1] ^= binary.LittleEndian.Uint64(
-                p[chunkSize+8*i+48*168:],
+                p[chunkSize+8*i+offset:],
             )
             a[i*4+2] ^= binary.LittleEndian.Uint64(
-                p[chunkSize*2+8*i+48*168:],
+                p[chunkSize*2+8*i+offset:],
             )
             a[i*4+3] ^= binary.LittleEndian.Uint64(
-                p[chunkSize*3+8*i+48*168:],
+                p[chunkSize*3+8*i+offset:],
             )
         }
-        a[16*4] ^= 0x0b
-        a[16*4+1] ^= 0x0b
-        a[16*4+2] ^= 0x0b
-        a[16*4+3] ^= 0x0b
-        a[20*4] ^= 0x80 << 56
-        a[20*4+1] ^= 0x80 << 56
-        a[20*4+2] ^= 0x80 << 56
-        a[20*4+3] ^= 0x80 << 56
 
-        x4.Permute()
+        x4.Permute()
+    }
 
-        var buf [32 * 4]byte
-        for i := 0; i < 4; i++ {
-            binary.LittleEndian.PutUint64(buf[8*i:], a[4*i])
-            binary.LittleEndian.PutUint64(buf[32+8*i:], a[4*i+1])
-            binary.LittleEndian.PutUint64(buf[32*2+8*i:], a[4*i+2])
-            binary.LittleEndian.PutUint64(buf[32*3+8*i:], a[4*i+3])
-        }
-
-        _, _ = s.stalk.Write(buf[:])
-        p = p[chunkSize*4:]
-        s.chunk += 4
+    for i := 0; i < 16; i++ {
+        a[i*4] ^= binary.LittleEndian.Uint64(
+            p[8*i+48*168:],
+        )
+        a[i*4+1] ^= binary.LittleEndian.Uint64(
+            p[chunkSize+8*i+48*168:],
+        )
+        a[i*4+2] ^= binary.LittleEndian.Uint64(
+            p[chunkSize*2+8*i+48*168:],
+        )
+        a[i*4+3] ^= binary.LittleEndian.Uint64(
+            p[chunkSize*3+8*i+48*168:],
+        )
     }
 
-    return p
+    a[16*4] ^= 0x0b
+    a[16*4+1] ^= 0x0b
+    a[16*4+2] ^= 0x0b
+    a[16*4+3] ^= 0x0b
+    a[20*4] ^= 0x80 << 56
+    a[20*4+1] ^= 0x80 << 56
+    a[20*4+2] ^= 0x80 << 56
+    a[20*4+3] ^= 0x80 << 56
+
+    x4.Permute()
+
+    for i := 0; i < 4; i++ {
+        binary.LittleEndian.PutUint64(out[8*i:], a[4*i])
+        binary.LittleEndian.PutUint64(out[32+8*i:], a[4*i+1])
+        binary.LittleEndian.PutUint64(out[32*2+8*i:], a[4*i+2])
+        binary.LittleEndian.PutUint64(out[32*3+8*i:], a[4*i+3])
+    }
 }
 
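The magic numbers in computeX4 above (and computeX2 below) all derive from the TurboSHAKE128 rate of 168 bytes: 8192 = 48*168 + 128, so each chunk is 48 full blocks (the offset loop over 21 words) plus a 128-byte tail (the 16-word loop), finished with the leaf domain-separation byte 0x0B and the final 0x80 padding bit. A sanity-check sketch of that arithmetic as it would look inside this package (these constants and the init func are illustrative, not part of the patch):

const (
    rate       = 168              // TurboSHAKE128 rate: 21 lanes of 8 bytes
    fullBlocks = chunkSize / rate // 48 full blocks per 8192-byte chunk
    tailBytes  = chunkSize % rate // 128 bytes = 16 lanes of 8 bytes
)

func init() {
    // Mirrors the loop bounds above: 48*168 + 16*8 == 8192.
    if fullBlocks != 48 || tailBytes != 16*8 {
        panic("unexpected chunk layout")
    }
}
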
-func (s *State) writeX2(p []byte) []byte {
+func computeX2(p, out []byte) {
     // TODO On M2 Pro, 1/3 of the time is spent on this function
     // and LittleEndian.Uint64 excluding the actual permutation.
     // Rewriting in assembler might be worthwhile.
-    for len(p) >= 2*chunkSize {
-        var x2 keccakf1600.StateX2
-        a := x2.Initialize(true)
-
-        for offset := 0; offset < 48*168; offset += 168 {
-            for i := 0; i < 21; i++ {
-                a[i*2] ^= binary.LittleEndian.Uint64(
-                    p[8*i+offset:],
-                )
-                a[i*2+1] ^= binary.LittleEndian.Uint64(
-                    p[chunkSize+8*i+offset:],
-                )
-            }
-
-            x2.Permute()
-        }
+    var x2 keccakf1600.StateX2
+    a := x2.Initialize(true)
 
-        for i := 0; i < 16; i++ {
+    for offset := 0; offset < 48*168; offset += 168 {
+        for i := 0; i < 21; i++ {
             a[i*2] ^= binary.LittleEndian.Uint64(
-                p[8*i+48*168:],
+                p[8*i+offset:],
             )
             a[i*2+1] ^= binary.LittleEndian.Uint64(
-                p[chunkSize+8*i+48*168:],
+                p[chunkSize+8*i+offset:],
             )
         }
-        a[16*2] ^= 0x0b
-        a[16*2+1] ^= 0x0b
-        a[20*2] ^= 0x80 << 56
-        a[20*2+1] ^= 0x80 << 56
 
-        x2.Permute()
+        x2.Permute()
+    }
 
-        var buf [32 * 2]byte
-        for i := 0; i < 4; i++ {
-            binary.LittleEndian.PutUint64(buf[8*i:], a[2*i])
-            binary.LittleEndian.PutUint64(buf[32+8*i:], a[2*i+1])
-        }
-
-        _, _ = s.stalk.Write(buf[:])
-        p = p[chunkSize*2:]
-        s.chunk += 2
+    for i := 0; i < 16; i++ {
+        a[i*2] ^= binary.LittleEndian.Uint64(
+            p[8*i+48*168:],
+        )
+        a[i*2+1] ^= binary.LittleEndian.Uint64(
+            p[chunkSize+8*i+48*168:],
+        )
     }
 
-    return p
+    a[16*2] ^= 0x0b
+    a[16*2+1] ^= 0x0b
+    a[20*2] ^= 0x80 << 56
+    a[20*2+1] ^= 0x80 << 56
+
+    x2.Permute()
+
+    for i := 0; i < 4; i++ {
+        binary.LittleEndian.PutUint64(out[8*i:], a[2*i])
+        binary.LittleEndian.PutUint64(out[32+8*i:], a[2*i+1])
+    }
 }
 
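For comparison, each lane of computeX2/computeX4 computes the same 32-byte chaining value as this scalar sketch would, written as if it lived in this package (a hypothetical helper, not part of the patch; it assumes the internal sha3.NewTurboShake128 API already used by this file):

// computeX1 hashes a single 8192-byte chunk into its 32-byte chaining
// value CV_i, using domain-separation byte 0x0B as for all K12 leaves.
func computeX1(chunk, out []byte) {
    leaf := sha3.NewTurboShake128(0x0B)
    _, _ = leaf.Write(chunk[:chunkSize])
    _, _ = leaf.Read(out[:32])
}
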
 func (s *State) Read(p []byte) (int, error) {
@@ -355,6 +638,17 @@ func (s *State) Read(p []byte) (int, error) {
     buf[8] = byte(8 - i) // number of bytes to represent |C|
     _, _ = s.Write(buf[i:])
 
+    // If we're using parallel workers, mark that we're not writing anymore
+    // and wait for the jobs to complete.
+    if s.w != nil {
+        s.w.mux.Lock()
+        s.w.noMore = true
+        s.w.taskCond.Broadcast()
+        s.w.mux.Unlock()
+        s.w.wg.Wait()
+        s.w = nil
+    }
+
     // We need to write the chunk number if we're past the first chunk.
     if s.buf != nil {
         // Write last remaining chunk(s)
@@ -394,6 +688,7 @@ func (s *State) Read(p []byte) (int, error) {
             _, _ = s.stalk.Write(buf[i:])
             _, _ = s.stalk.Write([]byte{0xff, 0xff})
         }
+        s.buf = nil
     }
 
     return s.stalk.Read(p)
diff --git a/xof/k12/k12_test.go b/xof/k12/k12_test.go
index a5be5b05..776927b0 100644
--- a/xof/k12/k12_test.go
+++ b/xof/k12/k12_test.go
@@ -2,6 +2,7 @@ package k12
 
 import (
     "encoding/hex"
+    "runtime"
     "testing"
 )
 
@@ -16,8 +17,12 @@ func ptn(n int) []byte {
 }
 
 func testK12(t *testing.T, msg []byte, c []byte, l int, want string) {
-    do := func(lanes byte, writeSize int) {
-        h := newDraft10(c, lanes)
+    do := func(lanes byte, writeSize int, workers int) {
+        h := newDraft10(options{
+            context: c,
+            lanes:   lanes,
+            workers: workers,
+        })
         msg2 := msg
         for len(msg2) > 0 {
             to := writeSize
@@ -31,13 +36,16 @@ func testK12(t *testing.T, msg []byte, c []byte, l int, want string) {
         _, _ = h.Read(buf)
         got := hex.EncodeToString(buf)
         if want != got {
-            t.Fatalf("%s != %s (lanes=%d, writeSize=%d )", want, got, lanes, writeSize)
+            t.Fatalf("%s != %s (lanes=%d, writeSize=%d, workers=%d, len(msg)=%d)",
+                want, got, lanes, writeSize, workers, len(msg))
         }
     }
 
     for _, lanes := range []byte{1, 2, 4} {
-        for _, writeSize := range []int{7919, 1024, 8 * 1024} {
-            do(lanes, writeSize)
+        for _, workers := range []int{1, 4, runtime.NumCPU()} {
+            for _, writeSize := range []int{7919, 1024, 8 * 1024, chunkSize * int(lanes)} {
+                do(lanes, writeSize, workers)
+            }
         }
     }
 }
@@ -71,26 +79,56 @@ func TestK12(t *testing.T) {
     testK12(t, ptn(3*chunkSize+1), []byte{}, 16, "38cb940999aca742d69dd79298c6051c")
 }
 
-func BenchmarkK12_100B(b *testing.B) { benchmarkK12(b, 100, 1) }
-func BenchmarkK12_10K(b *testing.B)  { benchmarkK12(b, 10000, 1) }
-func BenchmarkK12_100K(b *testing.B) { benchmarkK12(b, 10000, 10) }
-func BenchmarkK12_1M(b *testing.B)   { benchmarkK12(b, 10000, 100) }
-func BenchmarkK12_10M(b *testing.B)  { benchmarkK12(b, 10000, 1000) }
+func BenchmarkK12_100B(b *testing.B)  { benchmarkK12(b, 1, 100) }
+func BenchmarkK12_10K(b *testing.B)   { benchmarkK12(b, 1, 10000) }
+func BenchmarkK12_100K(b *testing.B)  { benchmarkK12(b, 1, 100000) }
+func BenchmarkK12_3M(b *testing.B)    { benchmarkK12(b, 1, 3276800) }
+func BenchmarkK12_32M(b *testing.B)   { benchmarkK12(b, 1, 32768000) }
+func BenchmarkK12_327M(b *testing.B)  { benchmarkK12(b, 1, 327680000) }
+func BenchmarkK12_3276M(b *testing.B) { benchmarkK12(b, 1, 3276800000) }
+
+func BenchmarkK12x2_32M(b *testing.B)   { benchmarkK12(b, 2, 32768000) }
+func BenchmarkK12x2_327M(b *testing.B)  { benchmarkK12(b, 2, 327680000) }
+func BenchmarkK12x2_3276M(b *testing.B) { benchmarkK12(b, 2, 3276800000) }
+
+func BenchmarkK12x4_32M(b *testing.B)   { benchmarkK12(b, 4, 32768000) }
+func BenchmarkK12x4_327M(b *testing.B)  { benchmarkK12(b, 4, 327680000) }
+func BenchmarkK12x4_3276M(b *testing.B) { benchmarkK12(b, 4, 6553600000) }
+
+func BenchmarkK12x8_32M(b *testing.B)   { benchmarkK12(b, 8, 32768000) }
+func BenchmarkK12x8_327M(b *testing.B)  { benchmarkK12(b, 8, 327680000) }
+func BenchmarkK12x8_3276M(b *testing.B) { benchmarkK12(b, 8, 6553600000) }
+
+func BenchmarkK12xCPUs_32M(b *testing.B)   { benchmarkK12(b, 0, 32768000) }
+func BenchmarkK12xCPUs_327M(b *testing.B)  { benchmarkK12(b, 0, 327680000) }
+func BenchmarkK12xCPUs_3276M(b *testing.B) { benchmarkK12(b, 0, 6553600000) }
+
+func benchmarkK12(b *testing.B, workers, size int) {
+    if workers == 0 {
+        workers = runtime.NumCPU()
+    }
 
-func benchmarkK12(b *testing.B, size, num int) {
     b.StopTimer()
-    h := NewDraft10([]byte{})
-    data := make([]byte, size)
+    h := NewDraft10(WithWorkers(workers))
+    buf := make([]byte, h.MaxWriteSize())
     d := make([]byte, 32)
-    b.SetBytes(int64(size * num))
+    b.SetBytes(int64(size))
     b.StartTimer()
 
     for i := 0; i < b.N; i++ {
+        todo := size
         h.Reset()
-        for j := 0; j < num; j++ {
-            _, _ = h.Write(data)
+
+        for todo > 0 {
+            next := h.NextWriteSize()
+            if next > todo {
+                next = todo
+            }
+            _, _ = h.Write(buf[:next])
+            todo -= next
         }
+
         _, _ = h.Read(d)
     }
 }
diff --git a/xof/xof.go b/xof/xof.go
index 33485cac..fe8fbfce 100644
--- a/xof/xof.go
+++ b/xof/xof.go
@@ -58,7 +58,7 @@ func (x ID) New() XOF {
         x, _ := blake2s.NewXOF(blake2s.OutputLengthUnknown, nil)
         return blake2xs{x}
     case K12D10:
-        x := k12.NewDraft10([]byte{})
+        x := k12.NewDraft10()
         return k12d10{&x}
     default:
         panic("crypto: requested unavailable XOF function")
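The xof.go change above only drops the now-variadic context argument; since workers defaults to 1, callers going through the generic XOF interface see unchanged behavior. An end-to-end sketch of that path (assuming the existing xof package API; not part of the diff):

package main

import (
    "fmt"

    "github.com/cloudflare/circl/xof"
)

func main() {
    h := xof.K12D10.New() // now built via k12.NewDraft10() with default options
    _, _ = h.Write([]byte("abc"))
    out := make([]byte, 16)
    _, _ = h.Read(out)
    fmt.Printf("%x\n", out)
}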