Skip to content

Commit

Permalink
Adds conditional move and select to group.
Browse files Browse the repository at this point in the history
  • Loading branch information
armfazh committed Aug 1, 2022
1 parent c8971c0 commit 10a0004
Show file tree
Hide file tree
Showing 6 changed files with 227 additions and 12 deletions.
6 changes: 3 additions & 3 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module github.com/cloudflare/circl
go 1.16

require (
github.com/bwesterb/go-ristretto v1.2.1
golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d
golang.org/x/sys v0.0.0-20220624220833-87e55d714810
github.com/bwesterb/go-ristretto v1.2.2
golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa
golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10
)
12 changes: 6 additions & 6 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
github.com/bwesterb/go-ristretto v1.2.1 h1:Xd9ZXmjKE2aY8Ub7+4bX7tXsIPsV1pIZaUlJUjI1toE=
github.com/bwesterb/go-ristretto v1.2.1/go.mod h1:fUIoIZaG73pV5biE2Blr2xEzDoMj7NFEuV9ekS419A0=
golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d h1:sK3txAijHtOK88l68nt020reeT1ZdKLIYetKl95FzVY=
golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
github.com/bwesterb/go-ristretto v1.2.2 h1:S2C0mmSjCLS3H9+zfXoIoKzl+cOncvBvt6pE+zTm5Ms=
github.com/bwesterb/go-ristretto v1.2.2/go.mod h1:fUIoIZaG73pV5biE2Blr2xEzDoMj7NFEuV9ekS419A0=
golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa h1:zuSxTR4o9y82ebqCUJYNGJbGPo6sKVl54f/TVDObg1c=
golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220624220833-87e55d714810 h1:rHZQSjJdAI4Xf5Qzeh2bBc5YJIkPFVM6oDtMFYmgws0=
golang.org/x/sys v0.0.0-20220624220833-87e55d714810/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10 h1:WIoqL4EROvwiPdUtaip4VcDdpZ4kha7wBWZrbVKCIZg=
golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
5 changes: 5 additions & 0 deletions group/group.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ type Element interface {
Copy() Element
IsIdentity() bool
IsEqual(Element) bool
CMov(int, Element) Element
CSelect(int, Element, Element) Element
Add(Element, Element) Element
Dbl(Element) Element
Neg(Element) Element
Expand All @@ -53,6 +55,8 @@ type Scalar interface {
Copy() Scalar
IsEqual(Scalar) bool
SetUint64(uint64)
CMov(int, Scalar) Scalar
CSelect(int, Scalar, Scalar) Scalar
Add(Scalar, Scalar) Scalar
Sub(Scalar, Scalar) Scalar
Mul(Scalar, Scalar) Scalar
Expand All @@ -65,4 +69,5 @@ type Scalar interface {
var (
ErrType = errors.New("type mismatch")
ErrUnmarshal = errors.New("error unmarshaling")
ErrSelector = errors.New("group: selector must be 0 or 1")
)
108 changes: 105 additions & 3 deletions group/group_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ func TestGroup(t *testing.T) {
t.Run(n+"/Neg", func(tt *testing.T) { testNeg(tt, testTimes, g) })
t.Run(n+"/Mul", func(tt *testing.T) { testMul(tt, testTimes, g) })
t.Run(n+"/MulGen", func(tt *testing.T) { testMulGen(tt, testTimes, g) })
t.Run(n+"/CMov", func(tt *testing.T) { testCMov(tt, testTimes, g) })
t.Run(n+"/CSelect", func(tt *testing.T) { testCSelect(tt, testTimes, g) })
t.Run(n+"/Order", func(tt *testing.T) { testOrder(tt, testTimes, g) })
t.Run(n+"/Marshal", func(tt *testing.T) { testMarshal(tt, testTimes, g) })
t.Run(n+"/Scalar", func(tt *testing.T) { testScalar(tt, testTimes, g) })
Expand Down Expand Up @@ -101,6 +103,66 @@ func testMulGen(t *testing.T, testTimes int, g group.Group) {
}
}

func testCMov(t *testing.T, testTimes int, g group.Group) {
P := g.RandomElement(rand.Reader)
Q := g.RandomElement(rand.Reader)

err := test.CheckPanic(func() { P.CMov(0, Q) })
test.CheckIsErr(t, err, "shouldn't fail with 0")
err = test.CheckPanic(func() { P.CMov(1, Q) })
test.CheckIsErr(t, err, "shouldn't fail with 1")
err = test.CheckPanic(func() { P.CMov(2, Q) })
test.CheckNoErr(t, err, "should fail with dif 0,1")

for i := 0; i < testTimes; i++ {
P = g.RandomElement(rand.Reader)
Q = g.RandomElement(rand.Reader)

want := P.Copy()
got := P.CMov(0, Q)
if !got.IsEqual(want) {
test.ReportError(t, got, want)
}

want = Q.Copy()
got = P.CMov(1, Q)
if !got.IsEqual(want) {
test.ReportError(t, got, want)
}
}
}

func testCSelect(t *testing.T, testTimes int, g group.Group) {
P := g.RandomElement(rand.Reader)
Q := g.RandomElement(rand.Reader)
R := g.RandomElement(rand.Reader)

err := test.CheckPanic(func() { P.CSelect(0, Q, R) })
test.CheckIsErr(t, err, "shouldn't fail with 0")
err = test.CheckPanic(func() { P.CSelect(1, Q, R) })
test.CheckIsErr(t, err, "shouldn't fail with 1")
err = test.CheckPanic(func() { P.CSelect(2, Q, R) })
test.CheckNoErr(t, err, "should fail with dif 0,1")

for i := 0; i < testTimes; i++ {
P := g.RandomElement(rand.Reader)
Q := g.RandomElement(rand.Reader)
R := g.RandomElement(rand.Reader)

want := R.Copy()
got := P.CSelect(0, Q, R)
if !got.IsEqual(want) {
test.ReportError(t, got, want)
}

want = Q.Copy()
got = P.CSelect(1, Q, R)
if !got.IsEqual(want) {
test.ReportError(t, got, want)
}
}
}

func testOrder(t *testing.T, testTimes int, g group.Group) {
Q := g.NewElement()
order := g.Order()
Expand Down Expand Up @@ -179,16 +241,33 @@ func testMarshal(t *testing.T, testTimes int, g group.Group) {
}

func testScalar(t *testing.T, testTimes int, g group.Group) {
a := g.RandomScalar(rand.Reader)
b := g.RandomScalar(rand.Reader)
c := g.NewScalar()
d := g.NewScalar()
e := g.NewScalar()
f := g.NewScalar()
one := g.NewScalar()
one.SetUint64(1)
params := g.Params()

err := test.CheckPanic(func() { a.CMov(0, b) })
test.CheckIsErr(t, err, "shouldn't fail with 0")
err = test.CheckPanic(func() { a.CMov(1, b) })
test.CheckIsErr(t, err, "shouldn't fail with 1")
err = test.CheckPanic(func() { a.CMov(2, b) })
test.CheckNoErr(t, err, "should fail with dif 0,1")

err = test.CheckPanic(func() { a.CSelect(0, b, c) })
test.CheckIsErr(t, err, "shouldn't fail with 0")
err = test.CheckPanic(func() { a.CSelect(1, b, c) })
test.CheckIsErr(t, err, "shouldn't fail with 1")
err = test.CheckPanic(func() { a.CSelect(2, b, c) })
test.CheckNoErr(t, err, "should fail with dif 0,1")

for i := 0; i < testTimes; i++ {
a := g.RandomScalar(rand.Reader)
b := g.RandomScalar(rand.Reader)
a = g.RandomScalar(rand.Reader)
b = g.RandomScalar(rand.Reader)
c.Add(a, b)
d.Sub(a, b)
e.Mul(c, d)
Expand All @@ -207,9 +286,32 @@ func testScalar(t *testing.T, testTimes int, g group.Group) {
if l := uint(len(enc1)); l != params.ScalarLength {
test.ReportError(t, l, params.ScalarLength)
}

want := c.Copy()
got := c.CMov(0, a)
if !got.IsEqual(want) {
test.ReportError(t, got, want)
}

want = b.Copy()
got = d.CMov(1, b)
if !got.IsEqual(want) {
test.ReportError(t, got, want)
}

want = b.Copy()
got = e.CSelect(0, a, b)
if !got.IsEqual(want) {
test.ReportError(t, got, want)
}

want = a.Copy()
got = f.CSelect(1, a, b)
if !got.IsEqual(want) {
test.ReportError(t, got, want)
}
}

a := g.RandomScalar(rand.Reader)
c.Inv(a)
c.Mul(c, a)
if !one.IsEqual(c) {
Expand Down
36 changes: 36 additions & 0 deletions group/ristretto255.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,8 @@ func (g ristrettoGroup) HashToScalar(msg, dst []byte) Scalar {

func (e *ristrettoElement) Group() Group { return Ristretto255 }

func (e *ristrettoElement) String() string { return fmt.Sprintf("%x", e.p.Bytes()) }

func (e *ristrettoElement) IsIdentity() bool {
var zero r255.Point
zero.SetZero()
Expand All @@ -147,6 +149,23 @@ func (e *ristrettoElement) Copy() Element {
return &ristrettoElement{*new(r255.Point).Set(&e.p)}
}

func (e *ristrettoElement) CMov(v int, x Element) Element {
if !(v == 0 || v == 1) {
panic(ErrSelector)
}
e.p.ConditionalSet(&x.(*ristrettoElement).p, int32(v))
return e
}

func (e *ristrettoElement) CSelect(v int, x Element, y Element) Element {
if !(v == 0 || v == 1) {
panic(ErrSelector)
}
e.p.ConditionalSet(&x.(*ristrettoElement).p, int32(v))
e.p.ConditionalSet(&y.(*ristrettoElement).p, int32(1-v))
return e
}

func (e *ristrettoElement) Add(x Element, y Element) Element {
e.p.Add(&x.(*ristrettoElement).p, &y.(*ristrettoElement).p)
return e
Expand Down Expand Up @@ -200,6 +219,23 @@ func (s *ristrettoScalar) Copy() Scalar {
return &ristrettoScalar{*new(r255.Scalar).Set(&s.s)}
}

func (s *ristrettoScalar) CMov(v int, x Scalar) Scalar {
if !(v == 0 || v == 1) {
panic(ErrSelector)
}
s.s.ConditionalSet(&x.(*ristrettoScalar).s, int32(v))
return s
}

func (s *ristrettoScalar) CSelect(v int, x Scalar, y Scalar) Scalar {
if !(v == 0 || v == 1) {
panic(ErrSelector)
}
s.s.ConditionalSet(&x.(*ristrettoScalar).s, int32(v))
s.s.ConditionalSet(&y.(*ristrettoScalar).s, int32(1-v))
return s
}

func (s *ristrettoScalar) Add(x Scalar, y Scalar) Scalar {
s.s.Add(&x.(*ristrettoScalar).s, &y.(*ristrettoScalar).s)
return s
Expand Down
72 changes: 72 additions & 0 deletions group/short.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,57 @@ func (e *wElt) Set(a Element) Element {
}

func (e *wElt) Copy() Element { return e.wG.zeroElement().Set(e) }

func (e *wElt) CMov(v int, a Element) Element {
if !(v == 0 || v == 1) {
panic(ErrSelector)
}
aa := e.cvtElt(a)
l := (e.wG.c.Params().BitSize + 7) / 8
bufE := make([]byte, l)
bufA := make([]byte, l)
e.x.FillBytes(bufE)
aa.x.FillBytes(bufA)
subtle.ConstantTimeCopy(v, bufE, bufA)
e.x.SetBytes(bufE)

e.y.FillBytes(bufE)
aa.y.FillBytes(bufA)
subtle.ConstantTimeCopy(v, bufE, bufA)
e.y.SetBytes(bufE)

return e
}

func (e *wElt) CSelect(v int, a Element, b Element) Element {
if !(v == 0 || v == 1) {
panic(ErrSelector)
}
aa, bb := e.cvtElt(a), e.cvtElt(b)
l := (e.wG.c.Params().BitSize + 7) / 8
bufE := make([]byte, l)
bufA := make([]byte, l)
bufB := make([]byte, l)

e.x.FillBytes(bufE)
aa.x.FillBytes(bufA)
bb.x.FillBytes(bufB)
for i := range bufE {
bufE[i] = byte(subtle.ConstantTimeSelect(v, int(bufA[i]), int(bufB[i])))
}
e.x.SetBytes(bufE)

e.y.FillBytes(bufE)
aa.y.FillBytes(bufA)
bb.y.FillBytes(bufB)
for i := range bufE {
bufE[i] = byte(subtle.ConstantTimeSelect(v, int(bufA[i]), int(bufB[i])))
}
e.y.SetBytes(bufE)

return e
}

func (e *wElt) Add(a, b Element) Element {
aa, bb := e.cvtElt(a), e.cvtElt(b)
e.x, e.y = e.c.Add(aa.x, aa.y, bb.x, bb.y)
Expand Down Expand Up @@ -244,6 +295,27 @@ func (s *wScl) Set(a Scalar) Scalar {
}

func (s *wScl) Copy() Scalar { return s.wG.zeroScalar().Set(s) }

func (s *wScl) CMov(v int, a Scalar) Scalar {
if !(v == 0 || v == 1) {
panic(ErrSelector)
}
aa := s.cvtScl(a)
subtle.ConstantTimeCopy(v, s.k, aa.k)
return s
}

func (s *wScl) CSelect(v int, a Scalar, b Scalar) Scalar {
if !(v == 0 || v == 1) {
panic(ErrSelector)
}
aa, bb := s.cvtScl(a), s.cvtScl(b)
for i := range s.k {
s.k[i] = byte(subtle.ConstantTimeSelect(v, int(aa.k[i]), int(bb.k[i])))
}
return s
}

func (s *wScl) Add(a, b Scalar) Scalar {
aa, bb := s.cvtScl(a), s.cvtScl(b)
r := new(big.Int)
Expand Down

0 comments on commit 10a0004

Please sign in to comment.