mirror of https://github.com/gtank/ristretto255
internal/scalar: replace FromBytes/IsCanonical with FromUniformBytes/FromCanonicalBytes
This commit is contained in:
parent
4baac9a766
commit
116bee5e3d
|
@ -52,15 +52,25 @@ func (s *Scalar) Mul(x, y *Scalar) *Scalar {
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
// FromBytes sets s = x mod l, where x should be 64 bytes of uniform randomness
|
// FromUniformBytes sets s = x mod l, where x should be 64 bytes of uniform
|
||||||
// interpreted in little-endian.
|
// randomness interpreted in little-endian.
|
||||||
func (s *Scalar) FromBytes(x []byte) *Scalar {
|
func (s *Scalar) FromUniformBytes(x []byte) *Scalar {
|
||||||
var wideBytes [64]byte
|
var wideBytes [64]byte
|
||||||
copy(wideBytes[:], x[:])
|
copy(wideBytes[:], x[:])
|
||||||
scReduce(s, &wideBytes)
|
scReduce(s, &wideBytes)
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// FromCanonicalBytes sets s = x, where x is a 32 bytes little-endian encoding
|
||||||
|
// of s, and returns whether x is a canonical encoding of s.
|
||||||
|
func (s *Scalar) FromCanonicalBytes(x []byte) bool {
|
||||||
|
if len(x) != 32 {
|
||||||
|
panic("invalid scalar length")
|
||||||
|
}
|
||||||
|
copy(s[:], x)
|
||||||
|
return scMinimal(s)
|
||||||
|
}
|
||||||
|
|
||||||
// reduce reduces s mod l returns it.
|
// reduce reduces s mod l returns it.
|
||||||
func (s *Scalar) reduce() *Scalar {
|
func (s *Scalar) reduce() *Scalar {
|
||||||
var wideBytes [64]byte
|
var wideBytes [64]byte
|
||||||
|
@ -80,11 +90,6 @@ func (s *Scalar) Bytes(b []byte) []byte {
|
||||||
return res
|
return res
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsCanonical returns true if s < l, false otherwise.
|
|
||||||
func (s *Scalar) IsCanonical() bool {
|
|
||||||
return scMinimal(s)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Equal returns 1 if v and u are equal, and 0 otherwise.
|
// Equal returns 1 if v and u are equal, and 0 otherwise.
|
||||||
func (s *Scalar) Equal(u *Scalar) int {
|
func (s *Scalar) Equal(u *Scalar) int {
|
||||||
var su, ss [32]byte
|
var su, ss [32]byte
|
||||||
|
|
|
@ -6,6 +6,7 @@ package scalar
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"math/big"
|
||||||
"testing"
|
"testing"
|
||||||
"testing/quick"
|
"testing/quick"
|
||||||
)
|
)
|
||||||
|
@ -17,9 +18,11 @@ var quickCheckConfig = &quick.Config{MaxCountScale: 1 << 10}
|
||||||
func TestFromBytesRoundTrip(t *testing.T) {
|
func TestFromBytesRoundTrip(t *testing.T) {
|
||||||
f1 := func(in, out [32]byte, sc Scalar) bool {
|
f1 := func(in, out [32]byte, sc Scalar) bool {
|
||||||
in[len(in)-1] &= (1 << 4) - 1 // Mask out top 4 bits for 252-bit numbers
|
in[len(in)-1] &= (1 << 4) - 1 // Mask out top 4 bits for 252-bit numbers
|
||||||
sc.FromBytes(in[:])
|
if ok := sc.FromCanonicalBytes(in[:]); !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
sc.Bytes(out[:0])
|
sc.Bytes(out[:0])
|
||||||
return bytes.Equal(in[:], out[:]) && sc.IsCanonical()
|
return bytes.Equal(in[:], out[:]) && scMinimal(&sc)
|
||||||
}
|
}
|
||||||
if err := quick.Check(f1, nil); err != nil {
|
if err := quick.Check(f1, nil); err != nil {
|
||||||
t.Errorf("failed bytes->scalar->bytes round-trip: %v", err)
|
t.Errorf("failed bytes->scalar->bytes round-trip: %v", err)
|
||||||
|
@ -27,7 +30,9 @@ func TestFromBytesRoundTrip(t *testing.T) {
|
||||||
|
|
||||||
f2 := func(sc1, sc2 Scalar, out [32]byte) bool {
|
f2 := func(sc1, sc2 Scalar, out [32]byte) bool {
|
||||||
sc1.Bytes(out[:0])
|
sc1.Bytes(out[:0])
|
||||||
sc2.FromBytes(out[:])
|
if ok := sc2.FromCanonicalBytes(out[:]); !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
sc1.reduce()
|
sc1.reduce()
|
||||||
sc2.reduce()
|
sc2.reduce()
|
||||||
|
@ -38,6 +43,30 @@ func TestFromBytesRoundTrip(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestFromUniformBytes(t *testing.T) {
|
||||||
|
mod, _ := new(big.Int).SetString("27742317777372353535851937790883648493", 10)
|
||||||
|
mod.Add(mod, new(big.Int).Lsh(big.NewInt(1), 252))
|
||||||
|
f := func(in [64]byte, sc Scalar) bool {
|
||||||
|
sc.FromUniformBytes(in[:])
|
||||||
|
if !scMinimal(&sc) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
b := sc.Bytes(nil)
|
||||||
|
for i := range b[:len(b)/2] {
|
||||||
|
b[i], b[len(b)-i-1] = b[len(b)-i-1], b[i]
|
||||||
|
}
|
||||||
|
scBig := new(big.Int).SetBytes(b)
|
||||||
|
for i := range in[:len(in)/2] {
|
||||||
|
in[i], in[len(in)-i-1] = in[len(in)-i-1], in[i]
|
||||||
|
}
|
||||||
|
inBig := new(big.Int).SetBytes(in[:])
|
||||||
|
return inBig.Mod(inBig, mod).Cmp(scBig) == 0
|
||||||
|
}
|
||||||
|
if err := quick.Check(f, nil); err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestMulDistributesOverAdd(t *testing.T) {
|
func TestMulDistributesOverAdd(t *testing.T) {
|
||||||
mulDistributesOverAdd := func(x, y, z Scalar) bool {
|
mulDistributesOverAdd := func(x, y, z Scalar) bool {
|
||||||
// Compute t1 = (x+y)*z
|
// Compute t1 = (x+y)*z
|
||||||
|
@ -52,7 +81,7 @@ func TestMulDistributesOverAdd(t *testing.T) {
|
||||||
t3.Mul(&y, &z)
|
t3.Mul(&y, &z)
|
||||||
t2.Add(&t2, &t3)
|
t2.Add(&t2, &t3)
|
||||||
|
|
||||||
return t1.Equal(&t2) == 1 && t1.IsCanonical() && t2.IsCanonical()
|
return t1.Equal(&t2) == 1 && scMinimal(&t1) && scMinimal(&t2)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := quick.Check(mulDistributesOverAdd, quickCheckConfig); err != nil {
|
if err := quick.Check(mulDistributesOverAdd, quickCheckConfig); err != nil {
|
||||||
|
|
Loading…
Reference in New Issue