Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

hpke: Enforces passing a slice of exact size to UnmarshalBinary for KEM keys #489

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
63 changes: 63 additions & 0 deletions hpke/kem_test.go
@@ -0,0 +1,63 @@
package hpke_test

import (
"fmt"
"testing"

"github.com/cloudflare/circl/hpke"
"github.com/cloudflare/circl/internal/test"
)

func TestKemKeysMarshal(t *testing.T) {
for _, kem := range []hpke.KEM{
hpke.KEM_P256_HKDF_SHA256,
hpke.KEM_P384_HKDF_SHA384,
hpke.KEM_P521_HKDF_SHA512,
hpke.KEM_X25519_HKDF_SHA256,
hpke.KEM_X448_HKDF_SHA512,
hpke.KEM_X25519_KYBER768_DRAFT00,
} {
checkIssue488(t, kem)
}
}

func checkIssue488(t *testing.T, kem hpke.KEM) {
scheme := kem.Scheme()
pk, sk, err := scheme.GenerateKeyPair()
if err != nil {
t.Fatal(err)
}
skBytes, err := sk.MarshalBinary()
test.CheckNoErr(t, err, "marshal private key")
pkBytes, err := pk.MarshalBinary()
test.CheckNoErr(t, err, "marshal public key")

t.Run(fmt.Sprintf("%v/PrivateKey", scheme.Name()), func(t *testing.T) {
N := scheme.PrivateKeySize()
buffer := make([]byte, N+1)
copy(buffer, skBytes)

// passing a buffer larger than the private key size should error (but no panic).
_, err := scheme.UnmarshalBinaryPrivateKey(buffer[:N+1])
test.CheckIsErr(t, err, "unmarshal private key should failed")

// passing a buffer of the exact size must be correct.
gotSk, err := scheme.UnmarshalBinaryPrivateKey(buffer[:N])
test.CheckNoErr(t, err, "unmarshal private key shouldn't fail")
test.CheckOk(sk.Equal(gotSk), "private keys are not equal", t)
})

t.Run(fmt.Sprintf("%v/PublicKey", scheme.Name()), func(t *testing.T) {
N := scheme.PublicKeySize()
buffer := make([]byte, N+1)
copy(buffer, pkBytes)

// passing a buffer larger than the public key size should error (but no panic).
_, err := scheme.UnmarshalBinaryPublicKey(buffer[:N+1])
test.CheckIsErr(t, err, "unmarshal public key should failed")

gotPk, err := scheme.UnmarshalBinaryPublicKey(buffer[:N])
test.CheckNoErr(t, err, "unmarshal public key shouldn't fail")
test.CheckOk(pk.Equal(gotPk), "public keys are not equal", t)
})
}
19 changes: 11 additions & 8 deletions hpke/shortkem.go
Expand Up @@ -53,6 +53,7 @@ func (s shortKEM) DeriveKeyPair(seed []byte) (kem.PublicKey, kem.PrivateKey) {
bitmask = 0x01
}

Nsk := s.PrivateKeySize()
dkpPrk := s.labeledExtract([]byte(""), []byte("dkp_prk"), seed)
var bytes []byte
ctr := 0
Expand All @@ -64,14 +65,12 @@ func (s shortKEM) DeriveKeyPair(seed []byte) (kem.PublicKey, kem.PrivateKey) {
dkpPrk,
[]byte("candidate"),
[]byte{byte(ctr)},
uint16(s.byteSize()),
uint16(Nsk),
)
bytes[0] &= bitmask
skBig.SetBytes(bytes)
}
l := s.PrivateKeySize()
sk := &shortKEMPrivKey{s, make([]byte, l), nil}
copy(sk.priv[l-len(bytes):], bytes)
sk := &shortKEMPrivKey{s, bytes, nil}
return sk.Public(), sk
}

Expand All @@ -83,11 +82,11 @@ func (s shortKEM) GenerateKeyPair() (kem.PublicKey, kem.PrivateKey, error) {

func (s shortKEM) UnmarshalBinaryPrivateKey(data []byte) (kem.PrivateKey, error) {
l := s.PrivateKeySize()
if len(data) < l {
return nil, ErrInvalidKEMPrivateKey
if len(data) != l {
return nil, kem.ErrPrivKeySize
}
sk := &shortKEMPrivKey{s, make([]byte, l), nil}
copy(sk.priv[l-len(data):l], data[:l])
copy(sk.priv, data[:l])
if !sk.validate() {
return nil, ErrInvalidKEMPrivateKey
}
Expand All @@ -96,7 +95,11 @@ func (s shortKEM) UnmarshalBinaryPrivateKey(data []byte) (kem.PrivateKey, error)
}

func (s shortKEM) UnmarshalBinaryPublicKey(data []byte) (kem.PublicKey, error) {
x, y := elliptic.Unmarshal(s, data)
l := s.PublicKeySize()
if len(data) != l {
return nil, kem.ErrPubKeySize
}
x, y := elliptic.Unmarshal(s, data[:l])
if x == nil {
return nil, ErrInvalidKEMPublicKey
}
Expand Down
13 changes: 7 additions & 6 deletions hpke/xkem.go
Expand Up @@ -58,13 +58,14 @@ func (x xKEM) DeriveKeyPair(seed []byte) (kem.PublicKey, kem.PrivateKey) {
if len(seed) != x.SeedSize() {
panic(kem.ErrSeedSize)
}
sk := &xKEMPrivKey{scheme: x, priv: make([]byte, x.size)}
Nsk := x.PrivateKeySize()
sk := &xKEMPrivKey{scheme: x, priv: make([]byte, Nsk)}
dkpPrk := x.labeledExtract([]byte(""), []byte("dkp_prk"), seed)
bytes := x.labeledExpand(
dkpPrk,
[]byte("sk"),
nil,
uint16(x.PrivateKeySize()),
uint16(Nsk),
)
copy(sk.priv, bytes)
return sk.Public(), sk
Expand All @@ -81,8 +82,8 @@ func (x xKEM) GenerateKeyPair() (kem.PublicKey, kem.PrivateKey, error) {

func (x xKEM) UnmarshalBinaryPrivateKey(data []byte) (kem.PrivateKey, error) {
l := x.PrivateKeySize()
if len(data) < l {
return nil, ErrInvalidKEMPrivateKey
if len(data) != l {
return nil, kem.ErrPrivKeySize
}
sk := &xKEMPrivKey{x, make([]byte, l), nil}
copy(sk.priv, data[:l])
Expand All @@ -94,8 +95,8 @@ func (x xKEM) UnmarshalBinaryPrivateKey(data []byte) (kem.PrivateKey, error) {

func (x xKEM) UnmarshalBinaryPublicKey(data []byte) (kem.PublicKey, error) {
l := x.PublicKeySize()
if len(data) < l {
return nil, ErrInvalidKEMPublicKey
if len(data) != l {
return nil, kem.ErrPubKeySize
}
pk := &xKEMPubKey{x, make([]byte, l)}
copy(pk.pub, data[:l])
Expand Down