Skip to content

Commit

Permalink
kyber: remove division by q in ciphertext compression
Browse files Browse the repository at this point in the history
On some platforms, division by q leaks some information on the
ciphertext by its timing. If a keypair is reused, and an attacker has access to
a decapsulation oracle, this reveals information on the private key.
This is known as "kyberslash2".

Note that this does not affect to the typical ephemeral usage in TLS.
  • Loading branch information
bwesterb committed Dec 30, 2023
1 parent 899732a commit 5950957
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 10 deletions.
28 changes: 18 additions & 10 deletions pke/kyber/internal/common/poly.go
Expand Up @@ -166,7 +166,7 @@ func (p *Poly) CompressMessageTo(m []byte) {

// Set p to Decompress_q(m, 1).
//
// Assumes d is in {3, 4, 5, 10, 11}. p will be normalized.
// Assumes d is in {4, 5, 10, 11}. p will be normalized.
func (p *Poly) Decompress(m []byte, d int) {
// Decompress_q(x, d) = ⌈(q/2ᵈ)x⌋
// = ⌊(q/2ᵈ)x+½⌋
Expand Down Expand Up @@ -244,20 +244,28 @@ func (p *Poly) Decompress(m []byte, d int) {

// Writes Compress_q(p, d) to m.
//
// Assumes p is normalized and d is in {3, 4, 5, 10, 11}.
// Assumes p is normalized and d is in {4, 5, 10, 11}.
func (p *Poly) CompressTo(m []byte, d int) {
// Compress_q(x, d) = ⌈(2ᵈ/q)x⌋ mod⁺ 2ᵈ
// = ⌊(2ᵈ/q)x+½⌋ mod⁺ 2ᵈ
// = ⌊((x << d) + q/2) / q⌋ mod⁺ 2ᵈ
// = DIV((x << d) + q/2, q) & ((1<<d) - 1)
//
// We approximate DIV(x, q) by computing (x*a)>>e, where a/(2^e) ≈ 1/q.
// For d in {10,11} we use 20,642,679/2^36, which computes division by x/q
// correctly for 0 ≤ x < 41,522,616, which fits (q << 11) + q/2 comfortably.
// For d in {4,5} we use 315/2^20, which doesn't compute division by x/q
// correctly for all inputs, but it's close enough that the end result
// of the compression is correct. The advantage is that we do not need
// to use a 64-bit intermediate value.
switch d {
case 4:
var t [8]uint16
idx := 0
for i := 0; i < N/8; i++ {
for j := 0; j < 8; j++ {
t[j] = uint16(((uint32(p[8*i+j])<<4)+uint32(Q)/2)/
uint32(Q)) & ((1 << 4) - 1)
t[j] = uint16((((uint32(p[8*i+j])<<4)+uint32(Q)/2)*315)>>
20) & ((1 << 4) - 1)
}
m[idx] = byte(t[0]) | byte(t[1]<<4)
m[idx+1] = byte(t[2]) | byte(t[3]<<4)
Expand All @@ -271,8 +279,8 @@ func (p *Poly) CompressTo(m []byte, d int) {
idx := 0
for i := 0; i < N/8; i++ {
for j := 0; j < 8; j++ {
t[j] = uint16(((uint32(p[8*i+j])<<5)+uint32(Q)/2)/
uint32(Q)) & ((1 << 5) - 1)
t[j] = uint16((((uint32(p[8*i+j])<<5)+uint32(Q)/2)*315)>>
20) & ((1 << 5) - 1)
}
m[idx] = byte(t[0]) | byte(t[1]<<5)
m[idx+1] = byte(t[1]>>3) | byte(t[2]<<2) | byte(t[3]<<7)
Expand All @@ -287,8 +295,8 @@ func (p *Poly) CompressTo(m []byte, d int) {
idx := 0
for i := 0; i < N/4; i++ {
for j := 0; j < 4; j++ {
t[j] = uint16(((uint32(p[4*i+j])<<10)+uint32(Q)/2)/
uint32(Q)) & ((1 << 10) - 1)
t[j] = uint16((uint64((uint32(p[4*i+j])<<10)+uint32(Q)/2)*
20642679)>>36) & ((1 << 10) - 1)
}
m[idx] = byte(t[0])
m[idx+1] = byte(t[0]>>8) | byte(t[1]<<2)
Expand All @@ -302,8 +310,8 @@ func (p *Poly) CompressTo(m []byte, d int) {
idx := 0
for i := 0; i < N/8; i++ {
for j := 0; j < 8; j++ {
t[j] = uint16(((uint32(p[8*i+j])<<11)+uint32(Q)/2)/
uint32(Q)) & ((1 << 11) - 1)
t[j] = uint16((uint64((uint32(p[8*i+j])<<11)+uint32(Q)/2)*
20642679)>>36) & ((1 << 11) - 1)
}
m[idx] = byte(t[0])
m[idx+1] = byte(t[0]>>8) | byte(t[1]<<3)
Expand Down
104 changes: 104 additions & 0 deletions pke/kyber/internal/common/poly_test.go
@@ -1,6 +1,7 @@
package common

import (
"bytes"
"crypto/rand"
"fmt"
"testing"
Expand Down Expand Up @@ -273,3 +274,106 @@ func TestNormalizeAgainstGeneric(t *testing.T) {
}
}
}

Check failure on line 276 in pke/kyber/internal/common/poly_test.go

View workflow job for this annotation

GitHub Actions / Go-1.21.1/amd64

File is not `gofumpt`-ed (gofumpt)
func (p *Poly) OldCompressTo(m []byte, d int) {
switch d {
case 4:
var t [8]uint16
idx := 0
for i := 0; i < N/8; i++ {
for j := 0; j < 8; j++ {
t[j] = uint16(((uint32(p[8*i+j])<<4)+uint32(Q)/2)/
uint32(Q)) & ((1 << 4) - 1)
}
m[idx] = byte(t[0]) | byte(t[1]<<4)
m[idx+1] = byte(t[2]) | byte(t[3]<<4)
m[idx+2] = byte(t[4]) | byte(t[5]<<4)
m[idx+3] = byte(t[6]) | byte(t[7]<<4)
idx += 4
}

case 5:
var t [8]uint16
idx := 0
for i := 0; i < N/8; i++ {
for j := 0; j < 8; j++ {
t[j] = uint16(((uint32(p[8*i+j])<<5)+uint32(Q)/2)/
uint32(Q)) & ((1 << 5) - 1)
}
m[idx] = byte(t[0]) | byte(t[1]<<5)
m[idx+1] = byte(t[1]>>3) | byte(t[2]<<2) | byte(t[3]<<7)
m[idx+2] = byte(t[3]>>1) | byte(t[4]<<4)
m[idx+3] = byte(t[4]>>4) | byte(t[5]<<1) | byte(t[6]<<6)
m[idx+4] = byte(t[6]>>2) | byte(t[7]<<3)
idx += 5
}

case 10:
var t [4]uint16
idx := 0
for i := 0; i < N/4; i++ {
for j := 0; j < 4; j++ {
t[j] = uint16(((uint32(p[4*i+j])<<10)+uint32(Q)/2)/
uint32(Q)) & ((1 << 10) - 1)
}
m[idx] = byte(t[0])
m[idx+1] = byte(t[0]>>8) | byte(t[1]<<2)
m[idx+2] = byte(t[1]>>6) | byte(t[2]<<4)
m[idx+3] = byte(t[2]>>4) | byte(t[3]<<6)
m[idx+4] = byte(t[3] >> 2)
idx += 5
}
case 11:
var t [8]uint16
idx := 0
for i := 0; i < N/8; i++ {
for j := 0; j < 8; j++ {
t[j] = uint16(((uint32(p[8*i+j])<<11)+uint32(Q)/2)/
uint32(Q)) & ((1 << 11) - 1)
}
m[idx] = byte(t[0])
m[idx+1] = byte(t[0]>>8) | byte(t[1]<<3)
m[idx+2] = byte(t[1]>>5) | byte(t[2]<<6)
m[idx+3] = byte(t[2] >> 2)
m[idx+4] = byte(t[2]>>10) | byte(t[3]<<1)
m[idx+5] = byte(t[3]>>7) | byte(t[4]<<4)
m[idx+6] = byte(t[4]>>4) | byte(t[5]<<7)
m[idx+7] = byte(t[5] >> 1)
m[idx+8] = byte(t[5]>>9) | byte(t[6]<<2)
m[idx+9] = byte(t[6]>>6) | byte(t[7]<<5)
m[idx+10] = byte(t[7] >> 3)
idx += 11
}
default:
panic("unsupported d")
}
}

func TestCompressFullInputFirstCoeff(t *testing.T) {
for _, d := range []int{4, 5, 10, 11} {
d := d
t.Run(fmt.Sprintf("d=%d", d), func(t *testing.T) {
var p, q Poly
bound := (Q + (1 << uint(d))) >> uint(d+1)
buf := make([]byte, (N*d-1)/8+1)
buf2 := make([]byte, len(buf))
for i := int16(0); i < Q; i++ {
p[0] = i
p.CompressTo(buf, d)
p.OldCompressTo(buf2, d)
if !bytes.Equal(buf, buf2) {
t.Fatalf("%d", i)
}
q.Decompress(buf, d)
diff := sModQ(p[0] - q[0])
if diff < 0 {
diff = -diff
}
if diff > bound {
t.Logf("%v\n", buf)
t.Fatalf("|%d - %d mod^± q| = %d > %d",
p[0], q[0], diff, bound)
}
}
})
}
}

0 comments on commit 5950957

Please sign in to comment.