internal/scalar: address review comments

This commit is contained in:
Filippo Valsorda 2019-05-15 14:17:31 -04:00
parent 644c28a2d3
commit 013999a2d0
4 changed files with 33 additions and 31 deletions

View File

@ -9,6 +9,7 @@ package scalar
import ( import (
"crypto/subtle" "crypto/subtle"
"encoding/binary" "encoding/binary"
"errors"
) )
// A Scalar is an integer modulo // A Scalar is an integer modulo
@ -52,9 +53,12 @@ func (s *Scalar) Mul(x, y *Scalar) *Scalar {
return s return s
} }
// FromUniformBytes sets s = x mod l, where x should be 64 bytes of uniform // FromUniformBytes sets s to an uniformly distributed value given 64 uniformly
// randomness interpreted in little-endian. // distributed random bytes.
func (s *Scalar) FromUniformBytes(x []byte) *Scalar { func (s *Scalar) FromUniformBytes(x []byte) *Scalar {
if len(x) != 32 {
panic("scalar: invalid uniform input length")
}
var wideBytes [64]byte var wideBytes [64]byte
copy(wideBytes[:], x[:]) copy(wideBytes[:], x[:])
scReduce(s, &wideBytes) scReduce(s, &wideBytes)
@ -62,13 +66,17 @@ func (s *Scalar) FromUniformBytes(x []byte) *Scalar {
} }
// FromCanonicalBytes sets s = x, where x is a 32 bytes little-endian encoding // FromCanonicalBytes sets s = x, where x is a 32 bytes little-endian encoding
// of s, and returns whether x is a canonical encoding of s. // of s. If x is not a canonical encoding of s, FromCanonicalBytes returns an
func (s *Scalar) FromCanonicalBytes(x []byte) bool { // error and the receiver is unchanged.
func (s *Scalar) FromCanonicalBytes(x []byte) error {
if len(x) != 32 { if len(x) != 32 {
panic("invalid scalar length") panic("scalar: invalid scalar length")
}
if !scMinimal(x) {
return errors.New("invalid scalar encoding")
} }
copy(s[:], x) copy(s[:], x)
return scMinimal(s) return nil
} }
// reduce reduces s mod l returns it. // reduce reduces s mod l returns it.
@ -889,7 +897,7 @@ var order = [4]uint64{0x5812631a5cf5d3ed, 0x14def9dea2f79cd6, 0, 0x1000000000000
// scMinimal returns true if the given scalar is less than the order of the // scMinimal returns true if the given scalar is less than the order of the
// curve. // curve.
func scMinimal(sc *Scalar) bool { func scMinimal(sc []byte) bool {
for i := 3; ; i-- { for i := 3; ; i-- {
v := binary.LittleEndian.Uint64(sc[i*8:]) v := binary.LittleEndian.Uint64(sc[i*8:])
if v > order[i] { if v > order[i] {

View File

@ -18,11 +18,11 @@ var quickCheckConfig = &quick.Config{MaxCountScale: 1 << 10}
func TestFromBytesRoundTrip(t *testing.T) { func TestFromBytesRoundTrip(t *testing.T) {
f1 := func(in, out [32]byte, sc Scalar) bool { f1 := func(in, out [32]byte, sc Scalar) bool {
in[len(in)-1] &= (1 << 4) - 1 // Mask out top 4 bits for 252-bit numbers in[len(in)-1] &= (1 << 4) - 1 // Mask out top 4 bits for 252-bit numbers
if ok := sc.FromCanonicalBytes(in[:]); !ok { if err := sc.FromCanonicalBytes(in[:]); err != nil {
return false return false
} }
sc.Bytes(out[:0]) sc.Bytes(out[:0])
return bytes.Equal(in[:], out[:]) && scMinimal(&sc) return bytes.Equal(in[:], out[:]) && scMinimal(sc[:])
} }
if err := quick.Check(f1, nil); err != nil { if err := quick.Check(f1, nil); err != nil {
t.Errorf("failed bytes->scalar->bytes round-trip: %v", err) t.Errorf("failed bytes->scalar->bytes round-trip: %v", err)
@ -30,7 +30,7 @@ func TestFromBytesRoundTrip(t *testing.T) {
f2 := func(sc1, sc2 Scalar, out [32]byte) bool { f2 := func(sc1, sc2 Scalar, out [32]byte) bool {
sc1.Bytes(out[:0]) sc1.Bytes(out[:0])
if ok := sc2.FromCanonicalBytes(out[:]); !ok { if err := sc2.FromCanonicalBytes(out[:]); err != nil {
return false return false
} }
@ -48,17 +48,13 @@ func TestFromUniformBytes(t *testing.T) {
mod.Add(mod, new(big.Int).Lsh(big.NewInt(1), 252)) mod.Add(mod, new(big.Int).Lsh(big.NewInt(1), 252))
f := func(in [64]byte, sc Scalar) bool { f := func(in [64]byte, sc Scalar) bool {
sc.FromUniformBytes(in[:]) sc.FromUniformBytes(in[:])
if !scMinimal(&sc) { if !scMinimal(sc[:]) {
return false return false
} }
b := sc.Bytes(nil) b := sc.Bytes(nil)
for i := range b[:len(b)/2] { byteSwap(b) // convert to big endian for SetBytes
b[i], b[len(b)-i-1] = b[len(b)-i-1], b[i]
}
scBig := new(big.Int).SetBytes(b) scBig := new(big.Int).SetBytes(b)
for i := range in[:len(in)/2] { byteSwap(in[:]) // convert to big endian for SetBytes
in[i], in[len(in)-i-1] = in[len(in)-i-1], in[i]
}
inBig := new(big.Int).SetBytes(in[:]) inBig := new(big.Int).SetBytes(in[:])
return inBig.Mod(inBig, mod).Cmp(scBig) == 0 return inBig.Mod(inBig, mod).Cmp(scBig) == 0
} }
@ -67,6 +63,12 @@ func TestFromUniformBytes(t *testing.T) {
} }
} }
func byteSwap(b []byte) {
for i := range b[:len(b)/2] {
b[i], b[len(b)-i-1] = b[len(b)-i-1], b[i]
}
}
func TestMulDistributesOverAdd(t *testing.T) { func TestMulDistributesOverAdd(t *testing.T) {
mulDistributesOverAdd := func(x, y, z Scalar) bool { mulDistributesOverAdd := func(x, y, z Scalar) bool {
// Compute t1 = (x+y)*z // Compute t1 = (x+y)*z
@ -81,7 +83,7 @@ func TestMulDistributesOverAdd(t *testing.T) {
t3.Mul(&y, &z) t3.Mul(&y, &z)
t2.Add(&t2, &t3) t2.Add(&t2, &t3)
return t1.Equal(&t2) == 1 && scMinimal(&t1) && scMinimal(&t2) return t1.Equal(&t2) == 1 && scMinimal(t1[:]) && scMinimal(t2[:])
} }
if err := quick.Check(mulDistributesOverAdd, quickCheckConfig); err != nil { if err := quick.Check(mulDistributesOverAdd, quickCheckConfig); err != nil {

View File

@ -355,7 +355,7 @@ func (e *Element) Zero() *Element {
} }
// Base sets e to the canonical generator specified in // Base sets e to the canonical generator specified in
// draft-hdevalence-cfrg-ristretto-01, Section 3. // draft-hdevalence-cfrg-ristretto-01, Section 3, and returns e.
func (e *Element) Base() *Element { func (e *Element) Base() *Element {
e.r.Set(&edwards25519.B) e.r.Set(&edwards25519.B)
return e return e

View File

@ -5,14 +5,11 @@
package ristretto255 package ristretto255
import ( import (
"errors"
"github.com/gtank/ristretto255/internal/scalar" "github.com/gtank/ristretto255/internal/scalar"
) )
// A Scalar is an element of the ristretto255 scalar field, as specified in // A Scalar is an element of the ristretto255 scalar field, as specified in
// draft-hdevalence-cfrg-ristretto-01, Section 3.4. That is, an element of the // draft-hdevalence-cfrg-ristretto-01, Section 3.4. That is, an integer modulo
// group of order
// //
// l = 2^252 + 27742317777372353535851937790883648493 // l = 2^252 + 27742317777372353535851937790883648493
type Scalar struct { type Scalar struct {
@ -43,8 +40,8 @@ func (s *Scalar) Mul(x, y *Scalar) *Scalar {
return s return s
} }
// FromUniformBytes sets s = x mod l, where x should be 64 bytes of uniform // FromUniformBytes sets s to an uniformly distributed value given 64 uniformly
// randomness interpreted in little-endian. // distributed random bytes.
func (s *Scalar) FromUniformBytes(x []byte) *Scalar { func (s *Scalar) FromUniformBytes(x []byte) *Scalar {
s.s.FromUniformBytes(x) s.s.FromUniformBytes(x)
return s return s
@ -54,12 +51,7 @@ func (s *Scalar) FromUniformBytes(x []byte) *Scalar {
// not a canonical encoding of s, Decode returns an error and the receiver is // not a canonical encoding of s, Decode returns an error and the receiver is
// unchanged. // unchanged.
func (s *Scalar) Decode(x []byte) error { func (s *Scalar) Decode(x []byte) error {
var tmp scalar.Scalar return s.s.FromCanonicalBytes(x)
if !tmp.FromCanonicalBytes(x) {
return errors.New("invalid scalar encoding")
}
s.s = tmp
return nil
} }
// Encode appends a 32 bytes little-endian encoding of s to b. // Encode appends a 32 bytes little-endian encoding of s to b.