Skip to content

Commit

Permalink
Make ascon cipher go routine safe
Browse files Browse the repository at this point in the history
Signed-off-by: Monis Khan <mok@microsoft.com>
  • Loading branch information
enj authored and armfazh committed Mar 14, 2023
1 parent 278354d commit a5c5796
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 45 deletions.
91 changes: 46 additions & 45 deletions cipher/ascon/ascon.go
Expand Up @@ -60,7 +60,6 @@ const (
const permA = 12

type Cipher struct {
s [5]uint64
key [3]uint64
mode Mode
}
Expand Down Expand Up @@ -117,10 +116,11 @@ func (a *Cipher) Seal(dst, nonce, plaintext, additionalData []byte) []byte {
ret, out := sliceForAppend(dst, ptLen+TagSize)
ciphertext, tag := out[:ptLen], out[ptLen:]

a.initialize(nonce)
a.assocData(additionalData)
a.procText(plaintext, ciphertext, true)
a.finalize(tag)
var s [5]uint64
a.initialize(nonce, &s)
a.assocData(additionalData, &s)
a.procText(plaintext, ciphertext, true, &s)
a.finalize(tag, &s)

return ret
}
Expand Down Expand Up @@ -150,10 +150,11 @@ func (a *Cipher) Open(dst, nonce, ciphertext, additionalData []byte) ([]byte, er
ciphertext, tag0 := ciphertext[:ptLen], ciphertext[ptLen:]
tag1 := (&[TagSize]byte{})[:]

a.initialize(nonce)
a.assocData(additionalData)
a.procText(ciphertext, plaintext, false)
a.finalize(tag1)
var s [5]uint64
a.initialize(nonce, &s)
a.assocData(additionalData, &s)
a.procText(ciphertext, plaintext, false, &s)
a.finalize(tag1, &s)

if subtle.ConstantTimeCompare(tag0, tag1) == 0 {
return nil, ErrDecryption
Expand All @@ -170,44 +171,44 @@ func (a *Cipher) blockSize() int { return abs(int(a.mode)) << 3 }
// permB = 6 for Ascon128 and Ascon80pq, or 8 for Ascon128a.
func (a *Cipher) permB() int { return (abs(int(a.mode)) + 2) << 1 }

func (a *Cipher) initialize(nonce []byte) {
func (a *Cipher) initialize(nonce []byte, s *[5]uint64) {
bcs := uint64(a.blockSize())
pB := uint64(a.permB())
kS := uint64(a.mode.KeySize())

a.s[0] = ((kS * 8) << 56) | ((bcs * 8) << 48) | (permA << 40) | (pB << 32) | a.key[0]
a.s[1] = a.key[1]
a.s[2] = a.key[2]
a.s[3] = binary.BigEndian.Uint64(nonce[0:8])
a.s[4] = binary.BigEndian.Uint64(nonce[8:16])
s[0] = ((kS * 8) << 56) | ((bcs * 8) << 48) | (permA << 40) | (pB << 32) | a.key[0]
s[1] = a.key[1]
s[2] = a.key[2]
s[3] = binary.BigEndian.Uint64(nonce[0:8])
s[4] = binary.BigEndian.Uint64(nonce[8:16])

a.perm(permA)
perm(permA, s)

a.s[2] ^= a.key[0]
a.s[3] ^= a.key[1]
a.s[4] ^= a.key[2]
s[2] ^= a.key[0]
s[3] ^= a.key[1]
s[4] ^= a.key[2]
}

func (a *Cipher) assocData(add []byte) {
func (a *Cipher) assocData(add []byte, s *[5]uint64) {
bcs := a.blockSize()
pB := a.permB()
if len(add) > 0 {
for ; len(add) >= bcs; add = add[bcs:] {
for i := 0; i < bcs; i += 8 {
a.s[i/8] ^= binary.BigEndian.Uint64(add[i : i+8])
s[i/8] ^= binary.BigEndian.Uint64(add[i : i+8])
}
a.perm(pB)
perm(pB, s)
}
for i := 0; i < len(add); i++ {
a.s[i/8] ^= uint64(add[i]) << (56 - 8*(i%8))
s[i/8] ^= uint64(add[i]) << (56 - 8*(i%8))
}
a.s[len(add)/8] ^= uint64(0x80) << (56 - 8*(len(add)%8))
a.perm(pB)
s[len(add)/8] ^= uint64(0x80) << (56 - 8*(len(add)%8))
perm(pB, s)
}
a.s[4] ^= 0x01
s[4] ^= 0x01
}

func (a *Cipher) procText(in, out []byte, enc bool) {
func (a *Cipher) procText(in, out []byte, enc bool, s *[5]uint64) {
bcs := a.blockSize()
pB := a.permB()
mask := uint64(0)
Expand All @@ -218,45 +219,45 @@ func (a *Cipher) procText(in, out []byte, enc bool) {
for ; len(in) >= bcs; in, out = in[bcs:], out[bcs:] {
for i := 0; i < bcs; i += 8 {
inW := binary.BigEndian.Uint64(in[i : i+8])
outW := a.s[i/8] ^ inW
outW := s[i/8] ^ inW
binary.BigEndian.PutUint64(out[i:i+8], outW)

a.s[i/8] = (inW &^ mask) | (outW & mask)
s[i/8] = (inW &^ mask) | (outW & mask)
}
a.perm(pB)
perm(pB, s)
}

mask8 := byte(mask & 0xFF)
for i := 0; i < len(in); i++ {
off := 56 - (8 * (i % 8))
si := byte((a.s[i/8] >> off) & 0xFF)
si := byte((s[i/8] >> off) & 0xFF)
inB := in[i]
outB := si ^ inB
out[i] = outB
ss := inB&^mask8 | outB&mask8
a.s[i/8] = (a.s[i/8] &^ (0xFF << off)) | uint64(ss)<<off
s[i/8] = (s[i/8] &^ (0xFF << off)) | uint64(ss)<<off
}
a.s[len(in)/8] ^= uint64(0x80) << (56 - 8*(len(in)%8))
s[len(in)/8] ^= uint64(0x80) << (56 - 8*(len(in)%8))
}

func (a *Cipher) finalize(tag []byte) {
func (a *Cipher) finalize(tag []byte, s *[5]uint64) {
bcs := a.blockSize()
if a.mode == Ascon80pq {
a.s[bcs/8+0] ^= a.key[0]<<32 | a.key[1]>>32
a.s[bcs/8+1] ^= a.key[1]<<32 | a.key[2]>>32
a.s[bcs/8+2] ^= a.key[2] << 32
s[bcs/8+0] ^= a.key[0]<<32 | a.key[1]>>32
s[bcs/8+1] ^= a.key[1]<<32 | a.key[2]>>32
s[bcs/8+2] ^= a.key[2] << 32
} else {
a.s[bcs/8+0] ^= a.key[1]
a.s[bcs/8+1] ^= a.key[2]
s[bcs/8+0] ^= a.key[1]
s[bcs/8+1] ^= a.key[2]
}

a.perm(permA)
binary.BigEndian.PutUint64(tag[0:8], a.s[3]^a.key[1])
binary.BigEndian.PutUint64(tag[8:16], a.s[4]^a.key[2])
perm(permA, s)
binary.BigEndian.PutUint64(tag[0:8], s[3]^a.key[1])
binary.BigEndian.PutUint64(tag[8:16], s[4]^a.key[2])
}

func (a *Cipher) perm(n int) {
x0, x1, x2, x3, x4 := a.s[0], a.s[1], a.s[2], a.s[3], a.s[4]
func perm(n int, s *[5]uint64) {
x0, x1, x2, x3, x4 := s[0], s[1], s[2], s[3], s[4]
for i := permA - n; i < permA; i++ {
// pC -- addition of constants
x2 ^= uint64((0xF-i)<<4 | i)
Expand Down Expand Up @@ -289,7 +290,7 @@ func (a *Cipher) perm(n int) {
x3 ^= bits.RotateLeft64(x3, -10) ^ bits.RotateLeft64(x3, -17)
x4 ^= bits.RotateLeft64(x4, -7) ^ bits.RotateLeft64(x4, -41)
}
a.s[0], a.s[1], a.s[2], a.s[3], a.s[4] = x0, x1, x2, x3, x4
s[0], s[1], s[2], s[3], s[4] = x0, x1, x2, x3, x4
}

// sliceForAppend takes a slice and a requested number of bytes. It returns a
Expand Down
21 changes: 21 additions & 0 deletions cipher/ascon/ascon_test.go
Expand Up @@ -8,6 +8,7 @@ import (
"io"
"os"
"strconv"
"sync"
"testing"

"github.com/cloudflare/circl/cipher/ascon"
Expand Down Expand Up @@ -181,6 +182,26 @@ func TestAPI(t *testing.T) {
}
test.CheckOk(&ctWithCap[0] == &plaintext[0], "should have same address", t)
})
t.Run("parallel", func(t *testing.T) {
var wg sync.WaitGroup
for i := 0; i < 1_000; i++ {
wg.Add(1)
go func() {
defer wg.Done()
ciphertext := c.Seal(nil, nonce, pt, nil)
plaintext, err := c.Open(nil, nonce, ciphertext, nil)
if err != nil {
t.Error(err)
}
got := plaintext
want := pt
if !bytes.Equal(got, want) {
test.ReportError(t, got, want)
}
}()
}
wg.Wait()
})
}

func BenchmarkAscon(b *testing.B) {
Expand Down

0 comments on commit a5c5796

Please sign in to comment.