ristretto255: expose scalar multiplication APIs

The names of the ScalarMults were picked to match elliptic.Curve.

The Scalar type is re-exposed as an opaque type, with an API that
matches the Element one.
This commit is contained in:
Filippo Valsorda 2019-05-12 04:43:38 -04:00
parent 97912109c3
commit 0e06c64ad7
4 changed files with 183 additions and 38 deletions

View File

@ -46,6 +46,14 @@ type AffineCached struct {
YplusX, YminusX, T2d radix51.FieldElement
}
// B is the Ed25519 basepoint.
var B = ProjP3{
X: radix51.FieldElement([5]uint64{1738742601995546, 1146398526822698, 2070867633025821, 562264141797630, 587772402128613}),
Y: radix51.FieldElement([5]uint64{1801439850948184, 1351079888211148, 450359962737049, 900719925474099, 1801439850948198}),
Z: radix51.FieldElement([5]uint64{1, 0, 0, 0, 0}),
T: radix51.FieldElement([5]uint64{1841354044333475, 16398895984059, 755974180946558, 900171276175154, 1821297809914039}),
}
// Constructors.
func (v *ProjP1xP1) Zero() *ProjP1xP1 {

View File

@ -6,18 +6,6 @@ package edwards25519
import (
"testing"
"github.com/gtank/ristretto255/internal/radix51"
)
var (
// The Ed25519 basepoint.
B = ProjP3{
X: radix51.FieldElement([5]uint64{1738742601995546, 1146398526822698, 2070867633025821, 562264141797630, 587772402128613}),
Y: radix51.FieldElement([5]uint64{1801439850948184, 1351079888211148, 450359962737049, 900719925474099, 1801439850948198}),
Z: radix51.FieldElement([5]uint64{1, 0, 0, 0, 0}),
T: radix51.FieldElement([5]uint64{1841354044333475, 16398895984059, 755974180946558, 900171276175154, 1821297809914039}),
}
)
func TestAddSubNegOnBasePoint(t *testing.T) {

View File

@ -3,8 +3,13 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package ristretto255 implements the ristretto255 prime-order group as
// specified in draft-hdevalence-cfrg-ristretto-00.
// Package ristretto255 implements the group of prime order
//
// 2**252 + 27742317777372353535851937790883648493
//
// as specified in draft-hdevalence-cfrg-ristretto-01.
//
// All operations are constant time unless otherwise specified.
package ristretto255
import (
@ -13,6 +18,7 @@ import (
"github.com/gtank/ristretto255/internal/edwards25519"
"github.com/gtank/ristretto255/internal/radix51"
"github.com/gtank/ristretto255/internal/scalar"
)
var (
@ -33,12 +39,13 @@ var (
// Element is an element of the ristretto255 prime-order group.
//
// The zero value of Element is not valid, but can be used as the receiver for
// any operation.
// any setting operation.
type Element struct {
r edwards25519.ProjP3
}
// Equal returns 1 if e is equivalent to ee, and 0 otherwise.
//
// Note that Elements must not be compared in any other way.
func (e *Element) Equal(ee *Element) int {
var f0, f1 radix51.FieldElement
@ -55,9 +62,9 @@ func (e *Element) Equal(ee *Element) int {
}
// FromUniformBytes maps the 64-byte slice b to e uniformly and
// deterministically. This can be used for hash-to-group operations or to obtain
// a random element.
func (e *Element) FromUniformBytes(b []byte) {
// deterministically, and returns e. This can be used for hash-to-group
// operations or to obtain a random element.
func (e *Element) FromUniformBytes(b []byte) *Element {
if len(b) != 64 {
panic("ristretto255: FromUniformBytes: input is not 64 bytes long")
}
@ -72,7 +79,7 @@ func (e *Element) FromUniformBytes(b []byte) {
point2 := &Element{}
mapToPoint(&point2.r, f)
e.Add(point1, point2)
return e.Add(point1, point2)
}
// mapToPoint implements MAP from Section 3.2.4 of draft-hdevalence-cfrg-ristretto-00.
@ -136,7 +143,8 @@ func mapToPoint(out *edwards25519.ProjP3, t *radix51.FieldElement) {
out.T.Mul(w0, w2)
}
// Encode appends the canonical encoding of e to b and returns the result.
// Encode appends the 32 bytes canonical encoding of e to b
// and returns the result.
func (e *Element) Encode(b []byte) []byte {
tmp := &radix51.FieldElement{}
@ -195,8 +203,8 @@ func (e *Element) Encode(b []byte) []byte {
return s.Bytes(b)
}
// Decode sets e to the decoded value of in. If in is not a canonical encoding,
// Decode returns an error, and the receiver is unchanged.
// Decode sets e to the decoded value of in. If in is not a 32 byte canonical
// encoding, Decode returns an error, and the receiver is unchanged.
func (e *Element) Decode(in []byte) error {
if len(in) != 32 {
return errInvalidEncoding
@ -266,26 +274,88 @@ func (e *Element) Decode(in []byte) error {
return nil
}
// Add sets v = p + q, and returns v.
func (v *Element) Add(p, q *Element) *Element {
v.r.Add(&p.r, &q.r)
return v
// ScalarBaseMult sets e = s * B, where B is the canonical generator, and returns e.
func (e *Element) ScalarBaseMult(s *Scalar) *Element {
e.r.BasepointMul(&s.s)
return e
}
// Sub sets v = p - q, and returns v.
func (v *Element) Sub(p, q *Element) *Element {
v.r.Sub(&p.r, &q.r)
return v
// ScalarMult sets e = s * p, and returns e.
func (e *Element) ScalarMult(s *Scalar, p *Element) *Element {
e.r.ScalarMul(&s.s, &p.r)
return e
}
// Neg sets v = -p, and returns v.
func (v *Element) Neg(p *Element) *Element {
v.r.Neg(&p.r)
return v
// MultiScalarMult sets e = sum(s[i] * p[i]), and returns e.
//
// Execution time depends only on the lengths of the two slices, which must match.
func (e *Element) MultiScalarMult(s []*Scalar, p []*Element) *Element {
if len(p) != len(s) {
panic("ristretto255: MultiScalarMult invoked with mismatched slice lengths")
}
points := make([]*edwards25519.ProjP3, len(p))
scalars := make([]scalar.Scalar, len(s))
for i := range s {
points[i] = &p[i].r
scalars[i] = s[i].s
}
e.r.MultiscalarMul(scalars, points)
return e
}
// Zero sets v to the identity element of the group, and returns v.
func (v *Element) Zero() *Element {
v.r.Zero()
return v
// VarTimeMultiScalarMult sets e = sum(s[i] * p[i]), and returns e.
//
// Execution time depends on the inputs.
func (e *Element) VarTimeMultiScalarMult(s []*Scalar, p []*Element) *Element {
if len(p) != len(s) {
panic("ristretto255: MultiScalarMult invoked with mismatched slice lengths")
}
points := make([]*edwards25519.ProjP3, len(p))
scalars := make([]scalar.Scalar, len(s))
for i := range s {
points[i] = &p[i].r
scalars[i] = s[i].s
}
e.r.VartimeMultiscalarMul(scalars, points)
return e
}
// VarTimeDoubleScalarBaseMult sets e = a * A + b * B, where B is the canonical
// generator, and returns e.
//
// Execution time depends on the inputs.
func (e *Element) VarTimeDoubleScalarBaseMult(a *Scalar, A *Element, b *Scalar) *Element {
e.r.VartimeDoubleBaseMul(&a.s, &A.r, &b.s)
return e
}
// Add sets e = p + q, and returns e.
func (e *Element) Add(p, q *Element) *Element {
e.r.Add(&p.r, &q.r)
return e
}
// Sub sets e = p - q, and returns e.
func (e *Element) Sub(p, q *Element) *Element {
e.r.Sub(&p.r, &q.r)
return e
}
// Neg sets e = -p, and returns e.
func (e *Element) Neg(p *Element) *Element {
e.r.Neg(&p.r)
return e
}
// Zero sets e to the identity element of the group, and returns e.
func (e *Element) Zero() *Element {
e.r.Zero()
return e
}
// Base sets e to the canonical generator specified in
// draft-hdevalence-cfrg-ristretto-01, Section 3.
func (e *Element) Base() *Element {
e.r.Set(&edwards25519.B)
return e
}

79
scalar.go Normal file
View File

@ -0,0 +1,79 @@
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ristretto255
import (
"errors"
"github.com/gtank/ristretto255/internal/scalar"
)
// A Scalar is an element of the ristretto255 scalar field, as specified in
// draft-hdevalence-cfrg-ristretto-01, Section 3.4. That is, an element of the
// group of order
//
// l = 2^252 + 27742317777372353535851937790883648493
type Scalar struct {
s scalar.Scalar
}
// Add sets s = x + y mod l and returns s.
func (s *Scalar) Add(x, y *Scalar) *Scalar {
s.s.Add(&x.s, &y.s)
return s
}
// Sub sets s = x - y mod l and returns s.
func (s *Scalar) Sub(x, y *Scalar) *Scalar {
s.s.Sub(&x.s, &y.s)
return s
}
// Neg sets s = -x mod l and returns s.
func (s *Scalar) Neg(x *Scalar) *Scalar {
s.s.Neg(&x.s)
return s
}
// Mul sets s = x * y mod l and returns s.
func (s *Scalar) Mul(x, y *Scalar) *Scalar {
s.s.Mul(&x.s, &y.s)
return s
}
// FromUniformBytes sets s = x mod l, where x should be 64 bytes of uniform
// randomness interpreted in little-endian.
func (s *Scalar) FromUniformBytes(x []byte) *Scalar {
s.s.FromUniformBytes(x)
return s
}
// Decode sets s = x, where x is a 32 bytes little-endian encoding of s. If x is
// not a canonical encoding of s, Decode returns an error and the receiver is
// unchanged.
func (s *Scalar) Decode(x []byte) error {
var tmp scalar.Scalar
if !tmp.FromCanonicalBytes(x) {
return errors.New("invalid scalar encoding")
}
s.s = tmp
return nil
}
// Encode appends a 32 bytes little-endian encoding of s to b.
func (s *Scalar) Encode(b []byte) []byte {
return s.s.Bytes(b)
}
// Equal returns 1 if v and u are equal, and 0 otherwise.
func (s *Scalar) Equal(u *Scalar) int {
return s.s.Equal(&u.s)
}
// Zero sets s = 0 and returns s.
func (s *Scalar) Zero() *Scalar {
s.s = scalar.Scalar{}
return s
}