From 0e06c64ad711874523cb2a69636f0646ab3ecb19 Mon Sep 17 00:00:00 2001 From: Filippo Valsorda Date: Sun, 12 May 2019 04:43:38 -0400 Subject: [PATCH] ristretto255: expose scalar multiplication APIs The names of the ScalarMults were picked to match elliptic.Curve. The Scalar type is re-exposed as an opaque type, with an API that matches the Element one. --- internal/edwards25519/edwards25519.go | 8 ++ internal/edwards25519/edwards25519_test.go | 12 -- ristretto255.go | 122 ++++++++++++++++----- scalar.go | 79 +++++++++++++ 4 files changed, 183 insertions(+), 38 deletions(-) create mode 100644 scalar.go diff --git a/internal/edwards25519/edwards25519.go b/internal/edwards25519/edwards25519.go index 91a54d7..a0808bc 100644 --- a/internal/edwards25519/edwards25519.go +++ b/internal/edwards25519/edwards25519.go @@ -46,6 +46,14 @@ type AffineCached struct { YplusX, YminusX, T2d radix51.FieldElement } +// B is the Ed25519 basepoint. +var B = ProjP3{ + X: radix51.FieldElement([5]uint64{1738742601995546, 1146398526822698, 2070867633025821, 562264141797630, 587772402128613}), + Y: radix51.FieldElement([5]uint64{1801439850948184, 1351079888211148, 450359962737049, 900719925474099, 1801439850948198}), + Z: radix51.FieldElement([5]uint64{1, 0, 0, 0, 0}), + T: radix51.FieldElement([5]uint64{1841354044333475, 16398895984059, 755974180946558, 900171276175154, 1821297809914039}), +} + // Constructors. func (v *ProjP1xP1) Zero() *ProjP1xP1 { diff --git a/internal/edwards25519/edwards25519_test.go b/internal/edwards25519/edwards25519_test.go index 0a281ce..d8b1de6 100644 --- a/internal/edwards25519/edwards25519_test.go +++ b/internal/edwards25519/edwards25519_test.go @@ -6,18 +6,6 @@ package edwards25519 import ( "testing" - - "github.com/gtank/ristretto255/internal/radix51" -) - -var ( - // The Ed25519 basepoint. - B = ProjP3{ - X: radix51.FieldElement([5]uint64{1738742601995546, 1146398526822698, 2070867633025821, 562264141797630, 587772402128613}), - Y: radix51.FieldElement([5]uint64{1801439850948184, 1351079888211148, 450359962737049, 900719925474099, 1801439850948198}), - Z: radix51.FieldElement([5]uint64{1, 0, 0, 0, 0}), - T: radix51.FieldElement([5]uint64{1841354044333475, 16398895984059, 755974180946558, 900171276175154, 1821297809914039}), - } ) func TestAddSubNegOnBasePoint(t *testing.T) { diff --git a/ristretto255.go b/ristretto255.go index 073d461..a9461f3 100644 --- a/ristretto255.go +++ b/ristretto255.go @@ -3,8 +3,13 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// Package ristretto255 implements the ristretto255 prime-order group as -// specified in draft-hdevalence-cfrg-ristretto-00. +// Package ristretto255 implements the group of prime order +// +// 2**252 + 27742317777372353535851937790883648493 +// +// as specified in draft-hdevalence-cfrg-ristretto-01. +// +// All operations are constant time unless otherwise specified. package ristretto255 import ( @@ -13,6 +18,7 @@ import ( "github.com/gtank/ristretto255/internal/edwards25519" "github.com/gtank/ristretto255/internal/radix51" + "github.com/gtank/ristretto255/internal/scalar" ) var ( @@ -33,12 +39,13 @@ var ( // Element is an element of the ristretto255 prime-order group. // // The zero value of Element is not valid, but can be used as the receiver for -// any operation. +// any setting operation. type Element struct { r edwards25519.ProjP3 } // Equal returns 1 if e is equivalent to ee, and 0 otherwise. +// // Note that Elements must not be compared in any other way. func (e *Element) Equal(ee *Element) int { var f0, f1 radix51.FieldElement @@ -55,9 +62,9 @@ func (e *Element) Equal(ee *Element) int { } // FromUniformBytes maps the 64-byte slice b to e uniformly and -// deterministically. This can be used for hash-to-group operations or to obtain -// a random element. -func (e *Element) FromUniformBytes(b []byte) { +// deterministically, and returns e. This can be used for hash-to-group +// operations or to obtain a random element. +func (e *Element) FromUniformBytes(b []byte) *Element { if len(b) != 64 { panic("ristretto255: FromUniformBytes: input is not 64 bytes long") } @@ -72,7 +79,7 @@ func (e *Element) FromUniformBytes(b []byte) { point2 := &Element{} mapToPoint(&point2.r, f) - e.Add(point1, point2) + return e.Add(point1, point2) } // mapToPoint implements MAP from Section 3.2.4 of draft-hdevalence-cfrg-ristretto-00. @@ -136,7 +143,8 @@ func mapToPoint(out *edwards25519.ProjP3, t *radix51.FieldElement) { out.T.Mul(w0, w2) } -// Encode appends the canonical encoding of e to b and returns the result. +// Encode appends the 32 bytes canonical encoding of e to b +// and returns the result. func (e *Element) Encode(b []byte) []byte { tmp := &radix51.FieldElement{} @@ -195,8 +203,8 @@ func (e *Element) Encode(b []byte) []byte { return s.Bytes(b) } -// Decode sets e to the decoded value of in. If in is not a canonical encoding, -// Decode returns an error, and the receiver is unchanged. +// Decode sets e to the decoded value of in. If in is not a 32 byte canonical +// encoding, Decode returns an error, and the receiver is unchanged. func (e *Element) Decode(in []byte) error { if len(in) != 32 { return errInvalidEncoding @@ -266,26 +274,88 @@ func (e *Element) Decode(in []byte) error { return nil } -// Add sets v = p + q, and returns v. -func (v *Element) Add(p, q *Element) *Element { - v.r.Add(&p.r, &q.r) - return v +// ScalarBaseMult sets e = s * B, where B is the canonical generator, and returns e. +func (e *Element) ScalarBaseMult(s *Scalar) *Element { + e.r.BasepointMul(&s.s) + return e } -// Sub sets v = p - q, and returns v. -func (v *Element) Sub(p, q *Element) *Element { - v.r.Sub(&p.r, &q.r) - return v +// ScalarMult sets e = s * p, and returns e. +func (e *Element) ScalarMult(s *Scalar, p *Element) *Element { + e.r.ScalarMul(&s.s, &p.r) + return e } -// Neg sets v = -p, and returns v. -func (v *Element) Neg(p *Element) *Element { - v.r.Neg(&p.r) - return v +// MultiScalarMult sets e = sum(s[i] * p[i]), and returns e. +// +// Execution time depends only on the lengths of the two slices, which must match. +func (e *Element) MultiScalarMult(s []*Scalar, p []*Element) *Element { + if len(p) != len(s) { + panic("ristretto255: MultiScalarMult invoked with mismatched slice lengths") + } + points := make([]*edwards25519.ProjP3, len(p)) + scalars := make([]scalar.Scalar, len(s)) + for i := range s { + points[i] = &p[i].r + scalars[i] = s[i].s + } + e.r.MultiscalarMul(scalars, points) + return e } -// Zero sets v to the identity element of the group, and returns v. -func (v *Element) Zero() *Element { - v.r.Zero() - return v +// VarTimeMultiScalarMult sets e = sum(s[i] * p[i]), and returns e. +// +// Execution time depends on the inputs. +func (e *Element) VarTimeMultiScalarMult(s []*Scalar, p []*Element) *Element { + if len(p) != len(s) { + panic("ristretto255: MultiScalarMult invoked with mismatched slice lengths") + } + points := make([]*edwards25519.ProjP3, len(p)) + scalars := make([]scalar.Scalar, len(s)) + for i := range s { + points[i] = &p[i].r + scalars[i] = s[i].s + } + e.r.VartimeMultiscalarMul(scalars, points) + return e +} + +// VarTimeDoubleScalarBaseMult sets e = a * A + b * B, where B is the canonical +// generator, and returns e. +// +// Execution time depends on the inputs. +func (e *Element) VarTimeDoubleScalarBaseMult(a *Scalar, A *Element, b *Scalar) *Element { + e.r.VartimeDoubleBaseMul(&a.s, &A.r, &b.s) + return e +} + +// Add sets e = p + q, and returns e. +func (e *Element) Add(p, q *Element) *Element { + e.r.Add(&p.r, &q.r) + return e +} + +// Sub sets e = p - q, and returns e. +func (e *Element) Sub(p, q *Element) *Element { + e.r.Sub(&p.r, &q.r) + return e +} + +// Neg sets e = -p, and returns e. +func (e *Element) Neg(p *Element) *Element { + e.r.Neg(&p.r) + return e +} + +// Zero sets e to the identity element of the group, and returns e. +func (e *Element) Zero() *Element { + e.r.Zero() + return e +} + +// Base sets e to the canonical generator specified in +// draft-hdevalence-cfrg-ristretto-01, Section 3. +func (e *Element) Base() *Element { + e.r.Set(&edwards25519.B) + return e } diff --git a/scalar.go b/scalar.go new file mode 100644 index 0000000..98810af --- /dev/null +++ b/scalar.go @@ -0,0 +1,79 @@ +// Copyright 2019 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ristretto255 + +import ( + "errors" + + "github.com/gtank/ristretto255/internal/scalar" +) + +// A Scalar is an element of the ristretto255 scalar field, as specified in +// draft-hdevalence-cfrg-ristretto-01, Section 3.4. That is, an element of the +// group of order +// +// l = 2^252 + 27742317777372353535851937790883648493 +type Scalar struct { + s scalar.Scalar +} + +// Add sets s = x + y mod l and returns s. +func (s *Scalar) Add(x, y *Scalar) *Scalar { + s.s.Add(&x.s, &y.s) + return s +} + +// Sub sets s = x - y mod l and returns s. +func (s *Scalar) Sub(x, y *Scalar) *Scalar { + s.s.Sub(&x.s, &y.s) + return s +} + +// Neg sets s = -x mod l and returns s. +func (s *Scalar) Neg(x *Scalar) *Scalar { + s.s.Neg(&x.s) + return s +} + +// Mul sets s = x * y mod l and returns s. +func (s *Scalar) Mul(x, y *Scalar) *Scalar { + s.s.Mul(&x.s, &y.s) + return s +} + +// FromUniformBytes sets s = x mod l, where x should be 64 bytes of uniform +// randomness interpreted in little-endian. +func (s *Scalar) FromUniformBytes(x []byte) *Scalar { + s.s.FromUniformBytes(x) + return s +} + +// Decode sets s = x, where x is a 32 bytes little-endian encoding of s. If x is +// not a canonical encoding of s, Decode returns an error and the receiver is +// unchanged. +func (s *Scalar) Decode(x []byte) error { + var tmp scalar.Scalar + if !tmp.FromCanonicalBytes(x) { + return errors.New("invalid scalar encoding") + } + s.s = tmp + return nil +} + +// Encode appends a 32 bytes little-endian encoding of s to b. +func (s *Scalar) Encode(b []byte) []byte { + return s.s.Bytes(b) +} + +// Equal returns 1 if v and u are equal, and 0 otherwise. +func (s *Scalar) Equal(u *Scalar) int { + return s.s.Equal(&u.s) +} + +// Zero sets s = 0 and returns s. +func (s *Scalar) Zero() *Scalar { + s.s = scalar.Scalar{} + return s +}