From 33575b0446cd3640a0a54f5acd438e406b4d3c12 Mon Sep 17 00:00:00 2001 From: armfazh Date: Wed, 13 Mar 2024 17:00:25 -0700 Subject: [PATCH] Fixes unmarshaling KEM keys when passing a larger buffer fo data. --- hpke/hybridkem.go | 12 ++++++++---- hpke/shortkem.go | 15 +++++++++------ 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/hpke/hybridkem.go b/hpke/hybridkem.go index 74e1ea6f..30873222 100644 --- a/hpke/hybridkem.go +++ b/hpke/hybridkem.go @@ -200,11 +200,13 @@ func (h hybridKEM) GenerateKeyPair() (kem.PublicKey, kem.PrivateKey, error) { } func (h hybridKEM) UnmarshalBinaryPrivateKey(data []byte) (kem.PrivateKey, error) { - skA, err := h.kemA.UnmarshalBinaryPrivateKey(data[0:h.kemA.PrivateKeySize()]) + lenA := h.kemA.PrivateKeySize() + skA, err := h.kemA.UnmarshalBinaryPrivateKey(data[0:lenA]) if err != nil { return nil, err } - skB, err := h.kemB.UnmarshalBinaryPrivateKey(data[h.kemA.PrivateKeySize():]) + lenB := h.kemB.PrivateKeySize() + skB, err := h.kemB.UnmarshalBinaryPrivateKey(data[lenA : lenA+lenB]) if err != nil { return nil, err } @@ -216,11 +218,13 @@ func (h hybridKEM) UnmarshalBinaryPrivateKey(data []byte) (kem.PrivateKey, error } func (h hybridKEM) UnmarshalBinaryPublicKey(data []byte) (kem.PublicKey, error) { - pkA, err := h.kemA.UnmarshalBinaryPublicKey(data[0:h.kemA.PublicKeySize()]) + lenA := h.kemA.PublicKeySize() + pkA, err := h.kemA.UnmarshalBinaryPublicKey(data[0:lenA]) if err != nil { return nil, err } - pkB, err := h.kemB.UnmarshalBinaryPublicKey(data[h.kemA.PublicKeySize():]) + lenB := h.kemB.PublicKeySize() + pkB, err := h.kemB.UnmarshalBinaryPublicKey(data[lenA : lenA+lenB]) if err != nil { return nil, err } diff --git a/hpke/shortkem.go b/hpke/shortkem.go index e5c55e99..71d7073e 100644 --- a/hpke/shortkem.go +++ b/hpke/shortkem.go @@ -53,6 +53,7 @@ func (s shortKEM) DeriveKeyPair(seed []byte) (kem.PublicKey, kem.PrivateKey) { bitmask = 0x01 } + Nsk := s.PrivateKeySize() dkpPrk := s.labeledExtract([]byte(""), []byte("dkp_prk"), seed) var bytes []byte ctr := 0 @@ -64,14 +65,12 @@ func (s shortKEM) DeriveKeyPair(seed []byte) (kem.PublicKey, kem.PrivateKey) { dkpPrk, []byte("candidate"), []byte{byte(ctr)}, - uint16(s.byteSize()), + uint16(Nsk), ) bytes[0] &= bitmask skBig.SetBytes(bytes) } - l := s.PrivateKeySize() - sk := &shortKEMPrivKey{s, make([]byte, l), nil} - copy(sk.priv[l-len(bytes):], bytes) + sk := &shortKEMPrivKey{s, bytes, nil} return sk.Public(), sk } @@ -87,7 +86,7 @@ func (s shortKEM) UnmarshalBinaryPrivateKey(data []byte) (kem.PrivateKey, error) return nil, ErrInvalidKEMPrivateKey } sk := &shortKEMPrivKey{s, make([]byte, l), nil} - copy(sk.priv[l-len(data):l], data[:l]) + copy(sk.priv, data[:l]) if !sk.validate() { return nil, ErrInvalidKEMPrivateKey } @@ -96,7 +95,11 @@ func (s shortKEM) UnmarshalBinaryPrivateKey(data []byte) (kem.PrivateKey, error) } func (s shortKEM) UnmarshalBinaryPublicKey(data []byte) (kem.PublicKey, error) { - x, y := elliptic.Unmarshal(s, data) + l := s.PublicKeySize() + if len(data) < l { + return nil, ErrInvalidKEMPublicKey + } + x, y := elliptic.Unmarshal(s, data[:l]) if x == nil { return nil, ErrInvalidKEMPublicKey }