diff --git a/kem/schemes/schemes.go b/kem/schemes/schemes.go index da5f9839..f3d019a3 100644 --- a/kem/schemes/schemes.go +++ b/kem/schemes/schemes.go @@ -29,6 +29,7 @@ import ( "github.com/cloudflare/circl/kem/mlkem/mlkem1024" "github.com/cloudflare/circl/kem/mlkem/mlkem512" "github.com/cloudflare/circl/kem/mlkem/mlkem768" + "github.com/cloudflare/circl/kem/xwing" ) var allSchemes = [...]kem.Scheme{ @@ -49,6 +50,7 @@ var allSchemes = [...]kem.Scheme{ hybrid.Kyber768X448(), hybrid.Kyber1024X448(), hybrid.P256Kyber768Draft00(), + xwing.Scheme(), } var allSchemeNames map[string]kem.Scheme diff --git a/kem/schemes/schemes_test.go b/kem/schemes/schemes_test.go index be4c18a5..0a89eec2 100644 --- a/kem/schemes/schemes_test.go +++ b/kem/schemes/schemes_test.go @@ -163,4 +163,5 @@ func Example_schemes() { // Kyber768-X448 // Kyber1024-X448 // P256Kyber768Draft00 + // X-Wing } diff --git a/kem/xwing/scheme.go b/kem/xwing/scheme.go new file mode 100644 index 00000000..ac1a3f98 --- /dev/null +++ b/kem/xwing/scheme.go @@ -0,0 +1,140 @@ +package xwing + +import ( + "bytes" + cryptoRand "crypto/rand" + "crypto/subtle" + + "github.com/cloudflare/circl/kem" + "github.com/cloudflare/circl/kem/mlkem/mlkem768" +) + +// This file contains the boilerplate code to connect X-Wing to the +// generic KEM API. + +// Returns the generic KEM interface for X-Wing PQ/T hybrid KEM. +func Scheme() kem.Scheme { return &xwing } + +type scheme struct{} + +var xwing scheme + +func (*scheme) Name() string { return "X-Wing" } +func (*scheme) PublicKeySize() int { return PublicKeySize } +func (*scheme) PrivateKeySize() int { return PrivateKeySize } +func (*scheme) SeedSize() int { return SeedSize } +func (*scheme) EncapsulationSeedSize() int { return EncapsulationSeedSize } +func (*scheme) SharedKeySize() int { return SharedKeySize } +func (*scheme) CiphertextSize() int { return CiphertextSize } +func (*PrivateKey) Scheme() kem.Scheme { return &xwing } +func (*PublicKey) Scheme() kem.Scheme { return &xwing } + +func (sch *scheme) Encapsulate(pk kem.PublicKey) (ct, ss []byte, err error) { + var seed [EncapsulationSeedSize]byte + _, err = cryptoRand.Read(seed[:]) + if err != nil { + return + } + return sch.EncapsulateDeterministically(pk, seed[:]) +} + +func (sch *scheme) EncapsulateDeterministically( + pk kem.PublicKey, seed []byte, +) ([]byte, []byte, error) { + if len(seed) != EncapsulationSeedSize { + return nil, nil, kem.ErrSeedSize + } + pub, ok := pk.(*PublicKey) + if !ok { + return nil, nil, kem.ErrTypeMismatch + } + var ( + ct [CiphertextSize]byte + ss [SharedKeySize]byte + ) + pub.EncapsulateTo(ct[:], ss[:], seed) + return ct[:], ss[:], nil +} + +func (*scheme) UnmarshalBinaryPublicKey(buf []byte) (kem.PublicKey, error) { + var pk PublicKey + if len(buf) != PublicKeySize { + return nil, kem.ErrPubKeySize + } + + pk.Unpack(buf) + return &pk, nil +} + +func (*scheme) UnmarshalBinaryPrivateKey(buf []byte) (kem.PrivateKey, error) { + var sk PrivateKey + if len(buf) != PrivateKeySize { + return nil, kem.ErrPrivKeySize + } + + sk.Unpack(buf) + return &sk, nil +} + +func (sk *PrivateKey) MarshalBinary() ([]byte, error) { + var ret [PrivateKeySize]byte + sk.Pack(ret[:]) + return ret[:], nil +} + +func (sk *PrivateKey) Equal(other kem.PrivateKey) bool { + oth, ok := other.(*PrivateKey) + if !ok { + return false + } + return sk.m.Equal(&oth.m) && + subtle.ConstantTimeCompare(oth.x[:], sk.x[:]) == 1 +} + +func (sk *PrivateKey) Public() kem.PublicKey { + var pk PublicKey + pk.m = *(sk.m.Public().(*mlkem768.PublicKey)) + pk.x = sk.xpk + return &pk +} + +func (pk *PublicKey) Equal(other kem.PublicKey) bool { + oth, ok := other.(*PublicKey) + if !ok { + return false + } + return pk.m.Equal(&oth.m) && bytes.Equal(pk.x[:], oth.x[:]) +} + +func (pk *PublicKey) MarshalBinary() ([]byte, error) { + var ret [PublicKeySize]byte + pk.Pack(ret[:]) + return ret[:], nil +} + +func (*scheme) DeriveKeyPair(seed []byte) (kem.PublicKey, kem.PrivateKey) { + sk, pk := DeriveKeyPair(seed) + return pk, sk +} + +func (sch *scheme) GenerateKeyPair() (kem.PublicKey, kem.PrivateKey, error) { + sk, pk, err := GenerateKeyPair(nil) + return pk, sk, err +} + +func (*scheme) Decapsulate(sk kem.PrivateKey, ct []byte) ([]byte, error) { + if len(ct) != CiphertextSize { + return nil, kem.ErrCiphertextSize + } + + var ss [SharedKeySize]byte + + priv, ok := sk.(*PrivateKey) + if !ok { + return nil, kem.ErrTypeMismatch + } + + priv.DecapsulateTo(ss[:], ct[:]) + + return ss[:], nil +} diff --git a/kem/xwing/xwing.go b/kem/xwing/xwing.go new file mode 100644 index 00000000..c6d4e55c --- /dev/null +++ b/kem/xwing/xwing.go @@ -0,0 +1,299 @@ +// xwing implements the X-Wing PQ/T hybrid KEM +// +// https://datatracker.ietf.org/doc/draft-connolly-cfrg-xwing-kem +package xwing + +import ( + cryptoRand "crypto/rand" + "errors" + "io" + + "github.com/cloudflare/circl/dh/x25519" + "github.com/cloudflare/circl/internal/sha3" + "github.com/cloudflare/circl/kem" + "github.com/cloudflare/circl/kem/mlkem/mlkem768" +) + +// An X-Wing private key. +type PrivateKey struct { + m mlkem768.PrivateKey + x x25519.Key + xpk x25519.Key // cache to prevent recomputation during each decapsulation +} + +// An X-Wing public key. +type PublicKey struct { + m mlkem768.PublicKey + x x25519.Key +} + +const ( + // Size of a seed of a keypair + SeedSize = 32 + + // Size of an X-Wing public key + PublicKeySize = 1216 + + // Size of an X-Wing private key + PrivateKeySize = 2432 + + // Size of the seed passed to EncapsulateTo + EncapsulationSeedSize = 32 + + // Size of the established shared key + SharedKeySize = 32 + + // Size of an X-Wing ciphertext. + CiphertextSize = 1120 +) + +func combiner( + out []byte, + ssm *[mlkem768.SharedKeySize]byte, + ssx *x25519.Key, + ctx *x25519.Key, + pkx *x25519.Key, +) { + h := sha3.New256() + // \./ + // /^\ + _, _ = h.Write([]byte(`\.//^\`)) + _, _ = h.Write(ssm[:]) + _, _ = h.Write(ssx[:]) + _, _ = h.Write(ctx[:]) + _, _ = h.Write(pkx[:]) + _, _ = h.Read(out[:]) +} + +// Packs sk to buf. +// +// Panics if buf is not of size PrivateKeySize +func (sk *PrivateKey) Pack(buf []byte) { + if len(buf) != PrivateKeySize { + panic(kem.ErrPrivKeySize) + } + sk.m.Pack(buf[:mlkem768.PrivateKeySize]) + copy(buf[mlkem768.PrivateKeySize:], sk.x[:]) +} + +// Packs pk to buf. +// +// Panics if buf is not of size PublicKeySize. +func (pk *PublicKey) Pack(buf []byte) { + if len(buf) != PublicKeySize { + panic(kem.ErrPubKeySize) + } + pk.m.Pack(buf[:mlkem768.PublicKeySize]) + copy(buf[mlkem768.PublicKeySize:], pk.x[:]) +} + +// DeriveKeyPair derives a public/private keypair deterministically +// from the given seed. +// +// Panics if seed is not of length SeedSize. +func DeriveKeyPair(seed []byte) (*PrivateKey, *PublicKey) { + if len(seed) != SeedSize { + panic(kem.ErrSeedSize) + } + + var ( + pk PublicKey + sk PrivateKey + seedm [mlkem768.KeySeedSize]byte + ) + h := sha3.NewShake128() + _, _ = h.Write(seed) + _, _ = h.Read(seedm[:]) + _, _ = h.Read(sk.x[:]) + + pkm, skm := mlkem768.NewKeyFromSeed(seedm[:]) + sk.m = *skm + pk.m = *pkm + + x25519.KeyGen(&pk.x, &sk.x) + sk.xpk = pk.x + + return &sk, &pk +} + +// DeriveKeyPairPacked derives a keypair like DeriveKeyPair, and +// returns them packed. +func DeriveKeyPairPacked(seed []byte) ([]byte, []byte) { + sk, pk := DeriveKeyPair(seed) + var ( + ppk [PublicKeySize]byte + psk [PrivateKeySize]byte + ) + pk.Pack(ppk[:]) + sk.Pack(psk[:]) + return psk[:], ppk[:] +} + +// GenerateKeyPair generates public and private keys using entropy from rand. +// If rand is nil, crypto/rand.Reader will be used. +func GenerateKeyPair(rand io.Reader) (*PrivateKey, *PublicKey, error) { + var seed [SeedSize]byte + if rand == nil { + rand = cryptoRand.Reader + } + _, err := io.ReadFull(rand, seed[:]) + if err != nil { + return nil, nil, err + } + sk, pk := DeriveKeyPair(seed[:]) + return sk, pk, nil +} + +// GenerateKeyPairPacked generates a keypair like GenerateKeyPair, and +// returns them packed. +func GenerateKeyPairPacked(rand io.Reader) ([]byte, []byte, error) { + sk, pk, err := GenerateKeyPair(rand) + if err != nil { + return nil, nil, err + } + var ( + ppk [PublicKeySize]byte + psk [PrivateKeySize]byte + ) + pk.Pack(ppk[:]) + sk.Pack(psk[:]) + return psk[:], ppk[:], nil +} + +// Encapsulate generates a shared key and ciphertext that contains it +// for the public key pk using randomness from seed. +// +// seed may be nil, in which case crypto/rand.Reader is used. +// +// Warning: note that the order of the returned ss and ct matches the +// X-Wing standard, which is the reverse of the Circl KEM API. +// +// Panics if pk is not of size PublicKeySize, or randomness could not +// be read from crypto/rand.Reader +func Encapsulate(pk, seed []byte) (ss, ct []byte) { + var pub PublicKey + pub.Unpack(pk) + ct = make([]byte, CiphertextSize) + ss = make([]byte, SharedKeySize) + pub.EncapsulateTo(ct, ss, seed) + return ss, ct +} + +// Decapsulate computes the shared key which is encapsulated in ct +// for the private key sk. +// +// Panics if sk or ct are not of length PrivateKeySize and CiphertextSize +// respectively. +func Decapsulate(ct, sk []byte) (ss []byte) { + var priv PrivateKey + priv.Unpack(sk) + ss = make([]byte, SharedKeySize) + priv.DecapsulateTo(ss, ct) + return ss +} + +// Raised when passing a byte slice of the wrong size for the shared +// secret to the EncapsulateTo or DecapsulateTo functions. +var ErrSharedKeySize = errors.New("wrong size for shared key") + +// EncapsulateTo generates a shared key and ciphertext that contains it +// for the public key using randomness from seed and writes the shared key +// to ss and ciphertext to ct. +// +// Panics if ss, ct or seed are not of length SharedKeySize, CiphertextSize +// and EncapsulationSeedSize respectively. +// +// seed may be nil, in which case crypto/rand.Reader is used to generate one. +func (pk *PublicKey) EncapsulateTo(ct, ss, seed []byte) { + if seed == nil { + seed = make([]byte, EncapsulationSeedSize) + if _, err := cryptoRand.Read(seed[:]); err != nil { + panic(err) + } + } else { + if len(seed) != EncapsulationSeedSize { + panic(kem.ErrSeedSize) + } + } + + if len(ct) != CiphertextSize { + panic(kem.ErrCiphertextSize) + } + + if len(ss) != SharedKeySize { + panic(ErrSharedKeySize) + } + + var ( + seedm [32]byte + ekx x25519.Key + ctx x25519.Key + ssx x25519.Key + ssm [mlkem768.SharedKeySize]byte + ) + + h := sha3.NewShake128() + _, _ = h.Write(seed) + _, _ = h.Read(seedm[:]) + _, _ = h.Read(ekx[:]) + + x25519.KeyGen(&ctx, &ekx) + x25519.Shared(&ssx, &ekx, &pk.x) + pk.m.EncapsulateTo(ct[:mlkem768.CiphertextSize], ssm[:], seedm[:]) + + combiner(ss, &ssm, &ssx, &ctx, &pk.x) + copy(ct[mlkem768.CiphertextSize:], ctx[:]) +} + +// DecapsulateTo computes the shared key which is encapsulated in ct +// for the private key. +// +// Panics if ct or ss are not of length CiphertextSize and SharedKeySize +// respectively. +func (sk *PrivateKey) DecapsulateTo(ss, ct []byte) { + if len(ct) != CiphertextSize { + panic(kem.ErrCiphertextSize) + } + if len(ss) != SharedKeySize { + panic(ErrSharedKeySize) + } + + ctm := ct[:mlkem768.CiphertextSize] + + var ( + ssm [mlkem768.SharedKeySize]byte + ssx x25519.Key + ctx x25519.Key + ) + + copy(ctx[:], ct[mlkem768.CiphertextSize:]) + + sk.m.DecapsulateTo(ssm[:], ctm) + x25519.Shared(&ssx, &sk.x, &ctx) + combiner(ss, &ssm, &ssx, &ctx, &sk.xpk) +} + +// Unpacks pk from buf. +// +// Panics if buf is not of size PublicKeySize. +func (pk *PublicKey) Unpack(buf []byte) { + if len(buf) != PublicKeySize { + panic(kem.ErrPubKeySize) + } + + copy(pk.x[:], buf[mlkem768.PublicKeySize:]) + pk.m.Unpack(buf[:mlkem768.PublicKeySize]) +} + +// Unpacks sk from buf. +// +// Panics if buf is not of size PrivateKeySize. +func (sk *PrivateKey) Unpack(buf []byte) { + if len(buf) != PrivateKeySize { + panic(kem.ErrPrivKeySize) + } + + copy(sk.x[:], buf[mlkem768.PrivateKeySize:]) + x25519.KeyGen(&sk.xpk, &sk.x) + sk.m.Unpack(buf[:mlkem768.PrivateKeySize]) +} diff --git a/kem/xwing/xwing_test.go b/kem/xwing/xwing_test.go new file mode 100644 index 00000000..175a2327 --- /dev/null +++ b/kem/xwing/xwing_test.go @@ -0,0 +1,73 @@ +package xwing + +import ( + "bytes" + "fmt" + "io" + "testing" + + "github.com/cloudflare/circl/internal/sha3" +) + +func writeHex(w io.Writer, prefix string, val interface{}) { + indent := " " + width := 74 + hex := fmt.Sprintf("%x", val) + if len(prefix)+len(hex)+1 < width { + fmt.Fprintf(w, "%s %s\n", prefix, hex) + return + } + fmt.Fprintf(w, "%s\n", prefix) + for len(hex) != 0 { + var toPrint string + if len(hex) < width-len(indent) { + toPrint = hex + hex = "" + } else { + toPrint = hex[:width-len(indent)] + hex = hex[width-len(indent):] + } + fmt.Fprintf(w, "%s%s\n", indent, toPrint) + } +} + +func TestVectors(t *testing.T) { + h := sha3.NewShake128() + w := new(bytes.Buffer) + + for i := 0; i < 3; i++ { + var seed [SeedSize]byte + _, _ = h.Read(seed[:]) + writeHex(w, "seed ", seed) + + sk, pk := DeriveKeyPairPacked(seed[:]) + writeHex(w, "sk ", sk) + writeHex(w, "pk ", pk) + + var eseed [EncapsulationSeedSize]byte + _, _ = h.Read(eseed[:]) + writeHex(w, "eseed ", eseed) + + ss, ct := Encapsulate(pk, eseed[:]) + writeHex(w, "ct ", ct) + writeHex(w, "ss ", ss) + + ss2 := Decapsulate(ct, sk) + if !bytes.Equal(ss, ss2) { + t.Fatal() + } + + fmt.Fprintf(w, "\n") + } + + t.Logf("%s", w.String()) + h.Reset() + _, _ = h.Write(w.Bytes()) + var cs [32]byte + _, _ = h.Read(cs[:]) + got := fmt.Sprintf("%x", cs) + want := "dff9d6258b66060ac402a8faa0114d6a8b683bfa8555eb630b764f2a3a709990" + if got != want { + t.Fatalf("%s ≠ %s", got, want) + } +}