diff --git a/internal/scalar/scalar.go b/internal/scalar/scalar.go index cf3b23a..9ed46cb 100644 --- a/internal/scalar/scalar.go +++ b/internal/scalar/scalar.go @@ -52,15 +52,25 @@ func (s *Scalar) Mul(x, y *Scalar) *Scalar { return s } -// FromBytes sets s = x mod l, where x should be 64 bytes of uniform randomness -// interpreted in little-endian. -func (s *Scalar) FromBytes(x []byte) *Scalar { +// FromUniformBytes sets s = x mod l, where x should be 64 bytes of uniform +// randomness interpreted in little-endian. +func (s *Scalar) FromUniformBytes(x []byte) *Scalar { var wideBytes [64]byte copy(wideBytes[:], x[:]) scReduce(s, &wideBytes) return s } +// FromCanonicalBytes sets s = x, where x is a 32 bytes little-endian encoding +// of s, and returns whether x is a canonical encoding of s. +func (s *Scalar) FromCanonicalBytes(x []byte) bool { + if len(x) != 32 { + panic("invalid scalar length") + } + copy(s[:], x) + return scMinimal(s) +} + // reduce reduces s mod l returns it. func (s *Scalar) reduce() *Scalar { var wideBytes [64]byte @@ -80,11 +90,6 @@ func (s *Scalar) Bytes(b []byte) []byte { return res } -// IsCanonical returns true if s < l, false otherwise. -func (s *Scalar) IsCanonical() bool { - return scMinimal(s) -} - // Equal returns 1 if v and u are equal, and 0 otherwise. func (s *Scalar) Equal(u *Scalar) int { var su, ss [32]byte diff --git a/internal/scalar/scalar_test.go b/internal/scalar/scalar_test.go index 09724e1..bbaae48 100644 --- a/internal/scalar/scalar_test.go +++ b/internal/scalar/scalar_test.go @@ -6,6 +6,7 @@ package scalar import ( "bytes" + "math/big" "testing" "testing/quick" ) @@ -17,9 +18,11 @@ var quickCheckConfig = &quick.Config{MaxCountScale: 1 << 10} func TestFromBytesRoundTrip(t *testing.T) { f1 := func(in, out [32]byte, sc Scalar) bool { in[len(in)-1] &= (1 << 4) - 1 // Mask out top 4 bits for 252-bit numbers - sc.FromBytes(in[:]) + if ok := sc.FromCanonicalBytes(in[:]); !ok { + return false + } sc.Bytes(out[:0]) - return bytes.Equal(in[:], out[:]) && sc.IsCanonical() + return bytes.Equal(in[:], out[:]) && scMinimal(&sc) } if err := quick.Check(f1, nil); err != nil { t.Errorf("failed bytes->scalar->bytes round-trip: %v", err) @@ -27,7 +30,9 @@ func TestFromBytesRoundTrip(t *testing.T) { f2 := func(sc1, sc2 Scalar, out [32]byte) bool { sc1.Bytes(out[:0]) - sc2.FromBytes(out[:]) + if ok := sc2.FromCanonicalBytes(out[:]); !ok { + return false + } sc1.reduce() sc2.reduce() @@ -38,6 +43,30 @@ func TestFromBytesRoundTrip(t *testing.T) { } } +func TestFromUniformBytes(t *testing.T) { + mod, _ := new(big.Int).SetString("27742317777372353535851937790883648493", 10) + mod.Add(mod, new(big.Int).Lsh(big.NewInt(1), 252)) + f := func(in [64]byte, sc Scalar) bool { + sc.FromUniformBytes(in[:]) + if !scMinimal(&sc) { + return false + } + b := sc.Bytes(nil) + for i := range b[:len(b)/2] { + b[i], b[len(b)-i-1] = b[len(b)-i-1], b[i] + } + scBig := new(big.Int).SetBytes(b) + for i := range in[:len(in)/2] { + in[i], in[len(in)-i-1] = in[len(in)-i-1], in[i] + } + inBig := new(big.Int).SetBytes(in[:]) + return inBig.Mod(inBig, mod).Cmp(scBig) == 0 + } + if err := quick.Check(f, nil); err != nil { + t.Error(err) + } +} + func TestMulDistributesOverAdd(t *testing.T) { mulDistributesOverAdd := func(x, y, z Scalar) bool { // Compute t1 = (x+y)*z @@ -52,7 +81,7 @@ func TestMulDistributesOverAdd(t *testing.T) { t3.Mul(&y, &z) t2.Add(&t2, &t3) - return t1.Equal(&t2) == 1 && t1.IsCanonical() && t2.IsCanonical() + return t1.Equal(&t2) == 1 && scMinimal(&t1) && scMinimal(&t2) } if err := quick.Check(mulDistributesOverAdd, quickCheckConfig); err != nil {