From 013999a2d0ae532da5d5c865b51cdf562ec7a3f9 Mon Sep 17 00:00:00 2001 From: Filippo Valsorda Date: Wed, 15 May 2019 14:17:31 -0400 Subject: [PATCH] internal/scalar: address review comments --- internal/scalar/scalar.go | 22 +++++++++++++++------- internal/scalar/scalar_test.go | 24 +++++++++++++----------- ristretto255.go | 2 +- scalar.go | 16 ++++------------ 4 files changed, 33 insertions(+), 31 deletions(-) diff --git a/internal/scalar/scalar.go b/internal/scalar/scalar.go index ec5776f..ad201f2 100644 --- a/internal/scalar/scalar.go +++ b/internal/scalar/scalar.go @@ -9,6 +9,7 @@ package scalar import ( "crypto/subtle" "encoding/binary" + "errors" ) // A Scalar is an integer modulo @@ -52,9 +53,12 @@ func (s *Scalar) Mul(x, y *Scalar) *Scalar { return s } -// FromUniformBytes sets s = x mod l, where x should be 64 bytes of uniform -// randomness interpreted in little-endian. +// FromUniformBytes sets s to an uniformly distributed value given 64 uniformly +// distributed random bytes. func (s *Scalar) FromUniformBytes(x []byte) *Scalar { + if len(x) != 32 { + panic("scalar: invalid uniform input length") + } var wideBytes [64]byte copy(wideBytes[:], x[:]) scReduce(s, &wideBytes) @@ -62,13 +66,17 @@ func (s *Scalar) FromUniformBytes(x []byte) *Scalar { } // FromCanonicalBytes sets s = x, where x is a 32 bytes little-endian encoding -// of s, and returns whether x is a canonical encoding of s. -func (s *Scalar) FromCanonicalBytes(x []byte) bool { +// of s. If x is not a canonical encoding of s, FromCanonicalBytes returns an +// error and the receiver is unchanged. +func (s *Scalar) FromCanonicalBytes(x []byte) error { if len(x) != 32 { - panic("invalid scalar length") + panic("scalar: invalid scalar length") + } + if !scMinimal(x) { + return errors.New("invalid scalar encoding") } copy(s[:], x) - return scMinimal(s) + return nil } // reduce reduces s mod l returns it. @@ -889,7 +897,7 @@ var order = [4]uint64{0x5812631a5cf5d3ed, 0x14def9dea2f79cd6, 0, 0x1000000000000 // scMinimal returns true if the given scalar is less than the order of the // curve. -func scMinimal(sc *Scalar) bool { +func scMinimal(sc []byte) bool { for i := 3; ; i-- { v := binary.LittleEndian.Uint64(sc[i*8:]) if v > order[i] { diff --git a/internal/scalar/scalar_test.go b/internal/scalar/scalar_test.go index bbaae48..df3648b 100644 --- a/internal/scalar/scalar_test.go +++ b/internal/scalar/scalar_test.go @@ -18,11 +18,11 @@ var quickCheckConfig = &quick.Config{MaxCountScale: 1 << 10} func TestFromBytesRoundTrip(t *testing.T) { f1 := func(in, out [32]byte, sc Scalar) bool { in[len(in)-1] &= (1 << 4) - 1 // Mask out top 4 bits for 252-bit numbers - if ok := sc.FromCanonicalBytes(in[:]); !ok { + if err := sc.FromCanonicalBytes(in[:]); err != nil { return false } sc.Bytes(out[:0]) - return bytes.Equal(in[:], out[:]) && scMinimal(&sc) + return bytes.Equal(in[:], out[:]) && scMinimal(sc[:]) } if err := quick.Check(f1, nil); err != nil { t.Errorf("failed bytes->scalar->bytes round-trip: %v", err) @@ -30,7 +30,7 @@ func TestFromBytesRoundTrip(t *testing.T) { f2 := func(sc1, sc2 Scalar, out [32]byte) bool { sc1.Bytes(out[:0]) - if ok := sc2.FromCanonicalBytes(out[:]); !ok { + if err := sc2.FromCanonicalBytes(out[:]); err != nil { return false } @@ -48,17 +48,13 @@ func TestFromUniformBytes(t *testing.T) { mod.Add(mod, new(big.Int).Lsh(big.NewInt(1), 252)) f := func(in [64]byte, sc Scalar) bool { sc.FromUniformBytes(in[:]) - if !scMinimal(&sc) { + if !scMinimal(sc[:]) { return false } b := sc.Bytes(nil) - for i := range b[:len(b)/2] { - b[i], b[len(b)-i-1] = b[len(b)-i-1], b[i] - } + byteSwap(b) // convert to big endian for SetBytes scBig := new(big.Int).SetBytes(b) - for i := range in[:len(in)/2] { - in[i], in[len(in)-i-1] = in[len(in)-i-1], in[i] - } + byteSwap(in[:]) // convert to big endian for SetBytes inBig := new(big.Int).SetBytes(in[:]) return inBig.Mod(inBig, mod).Cmp(scBig) == 0 } @@ -67,6 +63,12 @@ func TestFromUniformBytes(t *testing.T) { } } +func byteSwap(b []byte) { + for i := range b[:len(b)/2] { + b[i], b[len(b)-i-1] = b[len(b)-i-1], b[i] + } +} + func TestMulDistributesOverAdd(t *testing.T) { mulDistributesOverAdd := func(x, y, z Scalar) bool { // Compute t1 = (x+y)*z @@ -81,7 +83,7 @@ func TestMulDistributesOverAdd(t *testing.T) { t3.Mul(&y, &z) t2.Add(&t2, &t3) - return t1.Equal(&t2) == 1 && scMinimal(&t1) && scMinimal(&t2) + return t1.Equal(&t2) == 1 && scMinimal(t1[:]) && scMinimal(t2[:]) } if err := quick.Check(mulDistributesOverAdd, quickCheckConfig); err != nil { diff --git a/ristretto255.go b/ristretto255.go index 734e43a..d8f9535 100644 --- a/ristretto255.go +++ b/ristretto255.go @@ -355,7 +355,7 @@ func (e *Element) Zero() *Element { } // Base sets e to the canonical generator specified in -// draft-hdevalence-cfrg-ristretto-01, Section 3. +// draft-hdevalence-cfrg-ristretto-01, Section 3, and returns e. func (e *Element) Base() *Element { e.r.Set(&edwards25519.B) return e diff --git a/scalar.go b/scalar.go index 98810af..71b7f08 100644 --- a/scalar.go +++ b/scalar.go @@ -5,14 +5,11 @@ package ristretto255 import ( - "errors" - "github.com/gtank/ristretto255/internal/scalar" ) // A Scalar is an element of the ristretto255 scalar field, as specified in -// draft-hdevalence-cfrg-ristretto-01, Section 3.4. That is, an element of the -// group of order +// draft-hdevalence-cfrg-ristretto-01, Section 3.4. That is, an integer modulo // // l = 2^252 + 27742317777372353535851937790883648493 type Scalar struct { @@ -43,8 +40,8 @@ func (s *Scalar) Mul(x, y *Scalar) *Scalar { return s } -// FromUniformBytes sets s = x mod l, where x should be 64 bytes of uniform -// randomness interpreted in little-endian. +// FromUniformBytes sets s to an uniformly distributed value given 64 uniformly +// distributed random bytes. func (s *Scalar) FromUniformBytes(x []byte) *Scalar { s.s.FromUniformBytes(x) return s @@ -54,12 +51,7 @@ func (s *Scalar) FromUniformBytes(x []byte) *Scalar { // not a canonical encoding of s, Decode returns an error and the receiver is // unchanged. func (s *Scalar) Decode(x []byte) error { - var tmp scalar.Scalar - if !tmp.FromCanonicalBytes(x) { - return errors.New("invalid scalar encoding") - } - s.s = tmp - return nil + return s.s.FromCanonicalBytes(x) } // Encode appends a 32 bytes little-endian encoding of s to b.