From df5ea67995a9749742a711b09db443fd52dba217 Mon Sep 17 00:00:00 2001 From: armfazh Date: Wed, 13 Mar 2024 17:00:25 -0700 Subject: [PATCH] Enforces passing slices of the exact size when unmarshaling KEM keys. --- hpke/shortkem.go | 19 +++++++++++-------- hpke/xkem.go | 13 +++++++------ 2 files changed, 18 insertions(+), 14 deletions(-) diff --git a/hpke/shortkem.go b/hpke/shortkem.go index e5c55e99..cea17a97 100644 --- a/hpke/shortkem.go +++ b/hpke/shortkem.go @@ -53,6 +53,7 @@ func (s shortKEM) DeriveKeyPair(seed []byte) (kem.PublicKey, kem.PrivateKey) { bitmask = 0x01 } + Nsk := s.PrivateKeySize() dkpPrk := s.labeledExtract([]byte(""), []byte("dkp_prk"), seed) var bytes []byte ctr := 0 @@ -64,14 +65,12 @@ func (s shortKEM) DeriveKeyPair(seed []byte) (kem.PublicKey, kem.PrivateKey) { dkpPrk, []byte("candidate"), []byte{byte(ctr)}, - uint16(s.byteSize()), + uint16(Nsk), ) bytes[0] &= bitmask skBig.SetBytes(bytes) } - l := s.PrivateKeySize() - sk := &shortKEMPrivKey{s, make([]byte, l), nil} - copy(sk.priv[l-len(bytes):], bytes) + sk := &shortKEMPrivKey{s, bytes, nil} return sk.Public(), sk } @@ -83,11 +82,11 @@ func (s shortKEM) GenerateKeyPair() (kem.PublicKey, kem.PrivateKey, error) { func (s shortKEM) UnmarshalBinaryPrivateKey(data []byte) (kem.PrivateKey, error) { l := s.PrivateKeySize() - if len(data) < l { - return nil, ErrInvalidKEMPrivateKey + if len(data) != l { + return nil, kem.ErrPrivKeySize } sk := &shortKEMPrivKey{s, make([]byte, l), nil} - copy(sk.priv[l-len(data):l], data[:l]) + copy(sk.priv, data[:l]) if !sk.validate() { return nil, ErrInvalidKEMPrivateKey } @@ -96,7 +95,11 @@ func (s shortKEM) UnmarshalBinaryPrivateKey(data []byte) (kem.PrivateKey, error) } func (s shortKEM) UnmarshalBinaryPublicKey(data []byte) (kem.PublicKey, error) { - x, y := elliptic.Unmarshal(s, data) + l := s.PublicKeySize() + if len(data) != l { + return nil, kem.ErrPubKeySize + } + x, y := elliptic.Unmarshal(s, data[:l]) if x == nil { return nil, ErrInvalidKEMPublicKey } diff --git a/hpke/xkem.go b/hpke/xkem.go index f11ab6b3..19d89614 100644 --- a/hpke/xkem.go +++ b/hpke/xkem.go @@ -58,13 +58,14 @@ func (x xKEM) DeriveKeyPair(seed []byte) (kem.PublicKey, kem.PrivateKey) { if len(seed) != x.SeedSize() { panic(kem.ErrSeedSize) } - sk := &xKEMPrivKey{scheme: x, priv: make([]byte, x.size)} + Nsk := x.PrivateKeySize() + sk := &xKEMPrivKey{scheme: x, priv: make([]byte, Nsk)} dkpPrk := x.labeledExtract([]byte(""), []byte("dkp_prk"), seed) bytes := x.labeledExpand( dkpPrk, []byte("sk"), nil, - uint16(x.PrivateKeySize()), + uint16(Nsk), ) copy(sk.priv, bytes) return sk.Public(), sk @@ -81,8 +82,8 @@ func (x xKEM) GenerateKeyPair() (kem.PublicKey, kem.PrivateKey, error) { func (x xKEM) UnmarshalBinaryPrivateKey(data []byte) (kem.PrivateKey, error) { l := x.PrivateKeySize() - if len(data) < l { - return nil, ErrInvalidKEMPrivateKey + if len(data) != l { + return nil, kem.ErrPrivKeySize } sk := &xKEMPrivKey{x, make([]byte, l), nil} copy(sk.priv, data[:l]) @@ -94,8 +95,8 @@ func (x xKEM) UnmarshalBinaryPrivateKey(data []byte) (kem.PrivateKey, error) { func (x xKEM) UnmarshalBinaryPublicKey(data []byte) (kem.PublicKey, error) { l := x.PublicKeySize() - if len(data) < l { - return nil, ErrInvalidKEMPublicKey + if len(data) != l { + return nil, kem.ErrPubKeySize } pk := &xKEMPubKey{x, make([]byte, l)} copy(pk.pub, data[:l])