From ee040a35292ad79e74efa9567955e58a00704713 Mon Sep 17 00:00:00 2001 From: Filippo Valsorda Date: Tue, 9 Feb 2021 19:30:28 +0100 Subject: [PATCH] ristretto255: align the API with filippo.io/edwards25519 Fixes #36 Fixes #40 --- ristretto255.go | 78 +++++++++++++++++++++++++++++++++++++------- ristretto255_test.go | 4 +-- scalar.go | 41 +++++++++++++++++++++-- 3 files changed, 107 insertions(+), 16 deletions(-) diff --git a/ristretto255.go b/ristretto255.go index 9ee68c0..03ad57c 100644 --- a/ristretto255.go +++ b/ristretto255.go @@ -42,8 +42,24 @@ type Element struct { } // NewElement returns a new Element set to the identity value. +// +// Deprecated: use NewIdentityElement. This API will be removed before v1.0.0. func NewElement() *Element { - return (&Element{}).Zero() + return NewIdentityElement() +} + +// NewIdentityElement returns a new Element set to the identity value. +func NewIdentityElement() *Element { + e := &Element{} + e.r.Zero() + return e +} + +// NewGeneratorElement returns a new Element set to the canonical generator. +func NewGeneratorElement() *Element { + e := &Element{} + e.r.Set(&edwards25519.B) + return e } // Set sets the value of e to x and returns e. @@ -72,9 +88,22 @@ func (e *Element) Equal(ee *Element) int { // FromUniformBytes maps the 64-byte slice b to e uniformly and // deterministically, and returns e. This can be used for hash-to-group // operations or to obtain a random element. +// +// Deprecated: use SetUniformBytes. This API will be removed before v1.0.0. func (e *Element) FromUniformBytes(b []byte) *Element { + if _, err := e.SetUniformBytes(b); err != nil { + panic(err.Error()) + } + return e +} + +// SetUniformBytes deterministically sets e to an uniformly distributed value +// given 64 uniformly distributed random bytes. +// +// This can be used for hash-to-group operations or to obtain a random element. +func (e *Element) SetUniformBytes(b []byte) (*Element, error) { if len(b) != 64 { - panic("ristretto255: FromUniformBytes: input is not 64 bytes long") + return nil, errors.New("ristretto255: SetUniformBytes input is not 64 bytes long") } f := &radix51.FieldElement{} @@ -87,7 +116,7 @@ func (e *Element) FromUniformBytes(b []byte) *Element { point2 := &Element{} mapToPoint(&point2.r, f) - return e.Add(point1, point2) + return e.Add(point1, point2), nil } // mapToPoint implements MAP from Section 3.2.4 of draft-hdevalence-cfrg-ristretto-00. @@ -153,7 +182,20 @@ func mapToPoint(out *edwards25519.ProjP3, t *radix51.FieldElement) { // Encode appends the 32 bytes canonical encoding of e to b // and returns the result. +// +// Deprecated: use Bytes. This API will be removed before v1.0.0. func (e *Element) Encode(b []byte) []byte { + return e.bytes(b) +} + +// Bytes returns the 32 bytes canonical encoding of e. +func (e *Element) Bytes() []byte { + // Bytes is outlined to let the allocation happen on the stack of the caller. + b := make([]byte, 0, 32) + return e.bytes(b) +} + +func (e *Element) bytes(b []byte) []byte { tmp := &radix51.FieldElement{} // u1 = (z0 + y0) * (z0 - y0) @@ -215,9 +257,19 @@ var errInvalidEncoding = errors.New("invalid Ristretto encoding") // Decode sets e to the decoded value of in. If in is not a 32 byte canonical // encoding, Decode returns an error, and the receiver is unchanged. +// +// Deprecated: use SetCanonicalBytes. This API will be removed before v1.0.0. func (e *Element) Decode(in []byte) error { + _, err := e.SetCanonicalBytes(in) + return err +} + +// SetCanonicalBytes sets e to the decoded value of in. If in is not a canonical +// encoding of s, SetCanonicalBytes returns nil and an error and the receiver is +// unchanged. +func (e *Element) SetCanonicalBytes(in []byte) (*Element, error) { if len(in) != 32 { - return errInvalidEncoding + return nil, errInvalidEncoding } // First, interpret the string as an integer s in little-endian representation. @@ -227,12 +279,12 @@ func (e *Element) Decode(in []byte) error { // If the resulting value is >= p, decoding fails. var buf [32]byte if !bytes.Equal(s.Bytes(buf[:0]), in) { - return errInvalidEncoding + return nil, errInvalidEncoding } // If IS_NEGATIVE(s) returns TRUE, decoding fails. if s.IsNegative() == 1 { - return errInvalidEncoding + return nil, errInvalidEncoding } // ss = s^2 @@ -276,12 +328,12 @@ func (e *Element) Decode(in []byte) error { // If was_square is FALSE, or IS_NEGATIVE(t) returns TRUE, or y = 0, decoding fails. if wasSquare == 0 || out.T.IsNegative() == 1 || out.Y.Equal(radix51.Zero) == 1 { - return errInvalidEncoding + return nil, errInvalidEncoding } // Otherwise, return the internal representation in extended coordinates (x, y, 1, t). e.r.Set(&out) - return nil + return e, nil } // ScalarBaseMult sets e = s * B, where B is the canonical generator, and returns e. @@ -358,16 +410,18 @@ func (e *Element) Negate(p *Element) *Element { } // Zero sets e to the identity element of the group, and returns e. +// +// Deprecated: use NewIdentityElement and Set. This API will be removed before v1.0.0. func (e *Element) Zero() *Element { - e.r.Zero() - return e + return e.Set(NewIdentityElement()) } // Base sets e to the canonical generator specified in // draft-hdevalence-cfrg-ristretto-01, Section 3, and returns e. +// +// Deprecated: use NewGeneratorElement and Set. This API will be removed before v1.0.0. func (e *Element) Base() *Element { - e.r.Set(&edwards25519.B) - return e + return e.Set(NewGeneratorElement()) } // MarshalText implements encoding/TextMarshaler interface diff --git a/ristretto255_test.go b/ristretto255_test.go index 735a756..18a6834 100644 --- a/ristretto255_test.go +++ b/ristretto255_test.go @@ -307,8 +307,8 @@ func TestMarshalElement(t *testing.T) { func TestElementSet(t *testing.T) { // Test this, because the internal point type being hard-copyable isn't part of the spec. - el1 := NewElement().Zero() - el2 := NewElement().Base() + el1 := NewIdentityElement() + el2 := NewGeneratorElement() if el1.Equal(el2) == 1 { t.Error("shouldn't be the same") diff --git a/scalar.go b/scalar.go index 1dd9369..63a810d 100644 --- a/scalar.go +++ b/scalar.go @@ -6,6 +6,7 @@ package ristretto255 import ( "encoding/base64" + "errors" "github.com/gtank/ristretto255/internal/scalar" ) @@ -63,23 +64,59 @@ func (s *Scalar) Invert(x *Scalar) *Scalar { // FromUniformBytes sets s to an uniformly distributed value given 64 uniformly // distributed random bytes. +// +// Deprecated: use SetUniformBytes. This API will be removed before v1.0.0. func (s *Scalar) FromUniformBytes(x []byte) *Scalar { - s.s.FromUniformBytes(x) + if _, err := s.SetUniformBytes(x); err != nil { + panic(err.Error()) + } return s } +// SetUniformBytes sets s to an uniformly distributed value given 64 uniformly +// distributed random bytes. +func (s *Scalar) SetUniformBytes(x []byte) (*Scalar, error) { + if len(x) != 64 { + return nil, errors.New("ristretto255: SetUniformBytes input is not 64 bytes long") + } + s.s.FromUniformBytes(x) + return s, nil +} + // Decode sets s = x, where x is a 32 bytes little-endian encoding of s. If x is // not a canonical encoding of s, Decode returns an error and the receiver is // unchanged. +// +// Deprecated: use SetCanonicalBytes. This API will be removed before v1.0.0. func (s *Scalar) Decode(x []byte) error { - return s.s.FromCanonicalBytes(x) + _, err := s.SetCanonicalBytes(x) + return err +} + +// SetCanonicalBytes sets s = x, where x is a 32 bytes little-endian encoding of +// s. If x is not a canonical encoding of s, SetCanonicalBytes returns nil and +// an error and the receiver is unchanged. +func (s *Scalar) SetCanonicalBytes(x []byte) (*Scalar, error) { + if err := s.s.FromCanonicalBytes(x); err != nil { + return nil, err + } + return s, nil } // Encode appends a 32 bytes little-endian encoding of s to b. +// +// Deprecated: use Bytes. This API will be removed before v1.0.0. func (s *Scalar) Encode(b []byte) []byte { return s.s.Bytes(b) } +// Bytes returns the 32 bytes little-endian encoding of s. +func (s *Scalar) Bytes() []byte { + // Bytes is small, so the allocation happens on the stack of the caller. + b := make([]byte, 0, 32) + return s.s.Bytes(b) +} + // Equal returns 1 if v and u are equal, and 0 otherwise. func (s *Scalar) Equal(u *Scalar) int { return s.s.Equal(&u.s)