diff --git a/group/group.go b/group/group.go index 71c87e919..5ef5e51f8 100644 --- a/group/group.go +++ b/group/group.go @@ -5,6 +5,7 @@ import ( "encoding" "errors" "io" + "math/big" ) // Params stores the size in bytes of elements and scalars. @@ -105,8 +106,12 @@ type Scalar interface { IsZero() bool // IsEqual returns true if the receiver is equal to x. IsEqual(x Scalar) bool - // SetUint64 set the receiver to x, and returns the receiver. + // SetUint64 sets the receiver to x, and returns the receiver. SetUint64(x uint64) Scalar + // SetBigInt sets the receiver to x, and returns the receiver. + // Warning: operations on big.Int are not constant time. Do not use them + // for cryptography unless you're sure it's safe in your use-case. + SetBigInt(b *big.Int) Scalar // CMov sets the receiver to x if b=1; the receiver is unmodified if b=0; // otherwise panics if b is not 0 or 1. In all the cases, it returns the // receiver. diff --git a/group/ristretto255.go b/group/ristretto255.go index 4ec0baffa..c1312c18f 100644 --- a/group/ristretto255.go +++ b/group/ristretto255.go @@ -5,6 +5,7 @@ import ( _ "crypto/sha512" "fmt" "io" + "math/big" r255 "github.com/bwesterb/go-ristretto" "github.com/cloudflare/circl/expander" @@ -203,10 +204,11 @@ func (e *ristrettoElement) UnmarshalBinary(data []byte) error { return e.p.UnmarshalBinary(data) } -func (s *ristrettoScalar) Group() Group { return Ristretto255 } -func (s *ristrettoScalar) String() string { return conv.BytesLe2Hex(s.s.Bytes()) } -func (s *ristrettoScalar) SetUint64(n uint64) Scalar { s.s.SetUint64(n); return s } -func (s *ristrettoScalar) IsZero() bool { return s.s.IsNonZeroI() == 0 } +func (s *ristrettoScalar) Group() Group { return Ristretto255 } +func (s *ristrettoScalar) String() string { return conv.BytesLe2Hex(s.s.Bytes()) } +func (s *ristrettoScalar) SetUint64(n uint64) Scalar { s.s.SetUint64(n); return s } +func (s *ristrettoScalar) SetBigInt(x *big.Int) Scalar { s.s.SetBigInt(x); return s } +func (s *ristrettoScalar) IsZero() bool { return s.s.IsNonZeroI() == 0 } func (s *ristrettoScalar) IsEqual(x Scalar) bool { return s.s.Equals(&x.(*ristrettoScalar).s) } diff --git a/group/short.go b/group/short.go index 71473e8e0..c5ad3f2cc 100644 --- a/group/short.go +++ b/group/short.go @@ -271,9 +271,10 @@ type wScl struct { k []byte } -func (s *wScl) Group() Group { return s.wG } -func (s *wScl) String() string { return fmt.Sprintf("0x%x", s.k) } -func (s *wScl) SetUint64(n uint64) Scalar { s.fromBig(new(big.Int).SetUint64(n)); return s } +func (s *wScl) Group() Group { return s.wG } +func (s *wScl) String() string { return fmt.Sprintf("0x%x", s.k) } +func (s *wScl) SetUint64(n uint64) Scalar { s.fromBig(new(big.Int).SetUint64(n)); return s } +func (s *wScl) SetBigInt(x *big.Int) Scalar { s.fromBig(x); return s } func (s *wScl) IsZero() bool { return subtle.ConstantTimeCompare(s.k, make([]byte, (s.wG.c.Params().BitSize+7)/8)) == 1 }