internal: adapt Scalar to fiat's internal types

This commit is contained in:
George Tankersley 2021-02-19 16:59:03 -08:00
parent efdfd0f0c1
commit a83cb39e47
3 changed files with 91 additions and 110 deletions

View File

@ -18,7 +18,8 @@ var (
quickCheckConfig = &quick.Config{MaxCountScale: 1 << 6}
// a random scalar generated using dalek.
dalekScalar = scalar.Scalar([32]byte{219, 106, 114, 9, 174, 249, 155, 89, 69, 203, 201, 93, 92, 116, 234, 187, 78, 115, 103, 172, 182, 98, 62, 103, 187, 136, 13, 100, 248, 110, 12, 4})
dalekScalar scalar.Scalar
// the above, times the Ed25519 basepoint.
dalekScalarBasepoint = ProjP3{
X: radix51.FieldElement([5]uint64{778774234987948, 1589187156384239, 1213330452914652, 186161118421127, 2186284806803213}),
@ -28,6 +29,10 @@ var (
}
)
func init() {
_ = dalekScalar.FromCanonicalBytes([]byte{219, 106, 114, 9, 174, 249, 155, 89, 69, 203, 201, 93, 92, 116, 234, 187, 78, 115, 103, 172, 182, 98, 62, 103, 187, 136, 13, 100, 248, 110, 12, 4})
}
func TestScalarMulSmallScalars(t *testing.T) {
var z scalar.Scalar
var p, check ProjP3
@ -37,7 +42,8 @@ func TestScalarMulSmallScalars(t *testing.T) {
t.Error("0*B != 0")
}
z = scalar.Scalar([32]byte{1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0})
// z = montgomery domain 1
z = scalar.Scalar([4]uint64{0xd6ec31748d98951d, 0xc6ef5bf4737dcf70, 0xfffffffffffffffe, 0xfffffffffffffff})
p.ScalarMul(&z, &B)
check.Set(&B)
if check.Equal(&p) != 1 {
@ -76,12 +82,6 @@ func TestVartimeDoubleBaseMulVsDalek(t *testing.T) {
func TestScalarMulDistributesOverAdd(t *testing.T) {
scalarMulDistributesOverAdd := func(x, y scalar.Scalar) bool {
// The quickcheck generation strategy chooses a random
// 32-byte array, but we require that the high bit is
// unset. FIXME: make Scalar opaque. Until then,
// mask the high bits:
x[31] &= 127
y[31] &= 127
var z scalar.Scalar
z.Add(&x, &y)
var p, q, r, check ProjP3
@ -128,7 +128,7 @@ func TestBasepointTableGeneration(t *testing.T) {
func TestScalarMulMatchesBasepointMul(t *testing.T) {
scalarMulMatchesBasepointMul := func(x scalar.Scalar) bool {
// FIXME opaque scalars
x[31] &= 127
// x[31] &= 127
var p, q ProjP3
p.ScalarMul(&x, &B)
q.BasepointMul(&x)
@ -143,9 +143,9 @@ func TestScalarMulMatchesBasepointMul(t *testing.T) {
func TestMultiScalarMulMatchesBasepointMul(t *testing.T) {
multiScalarMulMatchesBasepointMul := func(x, y, z scalar.Scalar) bool {
// FIXME opaque scalars
x[31] &= 127
y[31] &= 127
z[31] &= 127
// x[31] &= 127
// y[31] &= 127
// z[31] &= 127
var p, q1, q2, q3, check ProjP3
p.MultiscalarMul([]scalar.Scalar{x, y, z}, []*ProjP3{&B, &B, &B})
@ -175,9 +175,6 @@ func TestBasepointNafTableGeneration(t *testing.T) {
func TestVartimeDoubleBaseMulMatchesBasepointMul(t *testing.T) {
vartimeDoubleBaseMulMatchesBasepointMul := func(x, y scalar.Scalar) bool {
// FIXME opaque scalars
x[31] &= 127
y[31] &= 127
var p, q1, q2, check ProjP3
p.VartimeDoubleBaseMul(&x, &B, &y)
@ -197,10 +194,6 @@ func TestVartimeDoubleBaseMulMatchesBasepointMul(t *testing.T) {
func TestVartimeMultiScalarMulMatchesBasepointMul(t *testing.T) {
vartimeMultiScalarMulMatchesBasepointMul := func(x, y, z scalar.Scalar) bool {
// FIXME opaque scalars
x[31] &= 127
y[31] &= 127
z[31] &= 127
var p, q1, q2, q3, check ProjP3
p.VartimeMultiscalarMul([]scalar.Scalar{x, y, z}, []*ProjP3{&B, &B, &B})

View File

@ -10,90 +10,40 @@ import (
"crypto/subtle"
"encoding/binary"
"errors"
mathrand "math/rand"
"reflect"
)
// A Scalar is an integer modulo
// l = 2^252 + 27742317777372353535851937790883648493,
// here represented as an opaque little-endian byte string.
type Scalar [32]byte
// A Scalar is an integer modulo l = 2^252 + 27742317777372353535851937790883648493.
// Internally, this implementation keeps the scalar in the Montgomery domain.
type Scalar [4]uint64
var (
scZero = Scalar([32]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0})
scOne = Scalar([32]byte{1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0})
// sage: l = GF(2**252 + 27742317777372353535851937790883648493)
// sage: l(-1).lift().digits(256)
scMinusOne = Scalar([32]byte{236, 211, 245, 92, 26, 99, 18, 88, 214, 156, 247, 162, 222, 249, 222, 20, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16})
// 1 in the Montgomery domain
scOne = Scalar([4]uint64{0xd6ec31748d98951d, 0xc6ef5bf4737dcf70, 0xfffffffffffffffe, 0xfffffffffffffff})
)
// Add sets s = x + y mod l and returns s.
func (s *Scalar) Add(x, y *Scalar) *Scalar {
var montyX [4]uint64
var montyY [4]uint64
fiat_sc255_from_bytes(&montyX, (*[32]byte)(x))
fiat_sc255_from_bytes(&montyY, (*[32]byte)(y))
fiat_sc255_to_montgomery(&montyX, &montyX)
fiat_sc255_to_montgomery(&montyY, &montyY)
var result [4]uint64
fiat_sc255_add(&result, &montyX, &montyY)
fiat_sc255_from_montgomery(&result, &result)
fiat_sc255_to_bytes((*[32]byte)(s), &result)
fiat_sc255_add((*[4]uint64)(s), (*[4]uint64)(x), (*[4]uint64)(y))
return s
}
// Sub sets s = x - y mod l and returns s.
func (s *Scalar) Sub(x, y *Scalar) *Scalar {
var montyX [4]uint64
var montyY [4]uint64
fiat_sc255_from_bytes(&montyX, (*[32]byte)(x))
fiat_sc255_from_bytes(&montyY, (*[32]byte)(y))
fiat_sc255_to_montgomery(&montyX, &montyX)
fiat_sc255_to_montgomery(&montyY, &montyY)
var result [4]uint64
fiat_sc255_sub(&result, &montyX, &montyY)
fiat_sc255_from_montgomery(&result, &result)
fiat_sc255_to_bytes((*[32]byte)(s), &result)
fiat_sc255_sub((*[4]uint64)(s), (*[4]uint64)(x), (*[4]uint64)(y))
return s
}
// Neg sets s = -x mod l and returns s.
func (s *Scalar) Neg(x *Scalar) *Scalar {
var montyX [4]uint64
fiat_sc255_from_bytes(&montyX, (*[32]byte)(x))
fiat_sc255_to_montgomery(&montyX, &montyX)
var result [4]uint64
fiat_sc255_opp(&result, &montyX)
fiat_sc255_from_montgomery(&result, &result)
fiat_sc255_to_bytes((*[32]byte)(s), &result)
fiat_sc255_opp((*[4]uint64)(s), (*[4]uint64)(x))
return s
}
// Mul sets s = x * y mod l and returns s.
func (s *Scalar) Mul(x, y *Scalar) *Scalar {
// scMulAdd(s, x, y, &scZero)
var montyX [4]uint64
var montyY [4]uint64
fiat_sc255_from_bytes(&montyX, (*[32]byte)(x))
fiat_sc255_from_bytes(&montyY, (*[32]byte)(y))
fiat_sc255_to_montgomery(&montyX, &montyX)
fiat_sc255_to_montgomery(&montyY, &montyY)
var result [4]uint64
fiat_sc255_mul(&result, &montyX, &montyY)
fiat_sc255_from_montgomery(&result, &result)
fiat_sc255_to_bytes((*[32]byte)(s), &result)
fiat_sc255_mul((*[4]uint64)(s), (*[4]uint64)(x), (*[4]uint64)(y))
return s
}
@ -103,9 +53,17 @@ func (s *Scalar) FromUniformBytes(x []byte) *Scalar {
if len(x) != 64 {
panic("scalar: invalid uniform input length")
}
var wideBytes [64]byte
copy(wideBytes[:], x[:])
scReduce(s, &wideBytes)
// TODO: scReduce is deprecated but retained here for consistent behavior
var reduced [32]byte
scReduce(&reduced, &wideBytes)
fiat_sc255_from_bytes((*[4]uint64)(s), &reduced)
fiat_sc255_to_montgomery((*[4]uint64)(s), (*[4]uint64)(s))
return s
}
@ -116,28 +74,31 @@ func (s *Scalar) FromCanonicalBytes(x []byte) error {
if len(x) != 32 {
panic("scalar: invalid scalar length")
}
if !scMinimal(x) {
return errors.New("invalid scalar encoding")
}
copy(s[:], x)
return nil
}
// reduce reduces s mod l returns it.
func (s *Scalar) reduce() *Scalar {
var wideBytes [64]byte
copy(wideBytes[:], s[:])
scReduce(s, &wideBytes)
return s
var b [32]byte
var in [4]uint64
copy(b[:], x)
fiat_sc255_from_bytes(&in, &b)
fiat_sc255_to_montgomery((*[4]uint64)(s), &in)
return nil
}
// Bytes appends a 32 bytes little-endian encoding of s to b.
func (s *Scalar) Bytes(b []byte) []byte {
t := *s
t.reduce()
var reduced [4]uint64
var repr [32]byte
fiat_sc255_from_montgomery(&reduced, (*[4]uint64)(s))
fiat_sc255_to_bytes(&repr, &reduced)
res, out := sliceForAppend(b, 32)
copy(out, t[:])
copy(out, repr[:])
return res
}
@ -185,7 +146,7 @@ func load4(in []byte) int64 {
// Output:
// s[0]+256*s[1]+...+256^31*s[31] = s mod l
// where l = 2^252 + 27742317777372353535851937790883648493.
func scReduce(out *Scalar, s *[64]byte) {
func scReduce(out *[32]byte, s *[64]byte) {
s0 := 2097151 & load3(s[:])
s1 := 2097151 & (load4(s[2:]) >> 5)
s2 := 2097151 & (load3(s[5:]) >> 2)
@ -506,9 +467,14 @@ func scReduce(out *Scalar, s *[64]byte) {
// order is the order of Curve25519 in little-endian form.
var order = [4]uint64{0x5812631a5cf5d3ed, 0x14def9dea2f79cd6, 0, 0x1000000000000000}
// scMinimal returns true if the given scalar is less than the order of the
// curve.
// scMinimal returns true if the given little-endian byte
// representation of a scalar NOT in the Montgomery domain
// is less than the order of the group.
func scMinimal(sc []byte) bool {
if len(sc) != 32 {
return false
}
for i := 3; ; i-- {
v := binary.LittleEndian.Uint64(sc[i*8:])
if v > order[i] {
@ -525,10 +491,12 @@ func scMinimal(sc []byte) bool {
// NonAdjacentForm computes a width-w non-adjacent form for this scalar.
func (s *Scalar) NonAdjacentForm(w uint) [256]int8 {
byteRepr := s.Bytes(nil)
// This implementation is adapted from the one
// in curve25519-dalek and is documented there:
// https://github.com/dalek-cryptography/curve25519-dalek/blob/f630041af28e9a405255f98a8a93adca18e4315b/src/scalar.rs#L800-L871
if s[31] > 127 {
if byteRepr[31] > 127 {
panic("scalar has high bit set illegally")
}
if w < 2 {
@ -541,7 +509,7 @@ func (s *Scalar) NonAdjacentForm(w uint) [256]int8 {
var digits [5]uint64
for i := 0; i < 4; i++ {
digits[i] = binary.LittleEndian.Uint64(s[i*8:])
digits[i] = binary.LittleEndian.Uint64(byteRepr[i*8:])
}
width := uint64(1 << w)
@ -589,7 +557,9 @@ func (s *Scalar) NonAdjacentForm(w uint) [256]int8 {
}
func (s *Scalar) SignedRadix16() [64]int8 {
if s[31] > 127 {
byteRepr := s.Bytes(nil)
if byteRepr[31] > 127 {
panic("scalar has high bit set illegally")
}
@ -597,8 +567,8 @@ func (s *Scalar) SignedRadix16() [64]int8 {
// Compute unsigned radix-16 digits:
for i := 0; i < 32; i++ {
digits[2*i] = int8(s[i] & 15)
digits[2*i+1] = int8((s[i] >> 4) & 15)
digits[2*i] = int8(byteRepr[i] & 15)
digits[2*i+1] = int8((byteRepr[i] >> 4) & 15)
}
// Recenter coefficients:
@ -721,3 +691,17 @@ func (s *Scalar) Inv(t *Scalar) *Scalar {
return s
}
func generateScalar(mrand *mathrand.Rand) Scalar {
var r [64]byte
mrand.Read(r[:])
s := (&Scalar{}).FromUniformBytes(r[:])
return *s
}
// Generate generates an arbitrary valid Scalar for quickcheck tests.
// It is here because it needs to be visible for other packages (currently internal/edwards25519)
// that want scalars in their tests.
func (sc Scalar) Generate(mrand *mathrand.Rand, size int) reflect.Value {
return reflect.ValueOf(generateScalar(mrand))
}

View File

@ -13,7 +13,7 @@ import (
// quickCheckConfig will make each quickcheck test run (1024 * -quickchecks)
// times. The default value of -quickchecks is 100.
var quickCheckConfig = &quick.Config{MaxCountScale: 1 << 10}
var quickCheckConfig = &quick.Config{MaxCountScale: 1 << 12}
func TestFromBytesRoundTrip(t *testing.T) {
f1 := func(in, out [32]byte, sc Scalar) bool {
@ -22,7 +22,7 @@ func TestFromBytesRoundTrip(t *testing.T) {
return false
}
sc.Bytes(out[:0])
return bytes.Equal(in[:], out[:]) && scMinimal(sc[:])
return bytes.Equal(in[:], out[:]) && scMinimal(out[:])
}
if err := quick.Check(f1, nil); err != nil {
t.Errorf("failed bytes->scalar->bytes round-trip: %v", err)
@ -34,9 +34,7 @@ func TestFromBytesRoundTrip(t *testing.T) {
return false
}
sc1.reduce()
sc2.reduce()
return sc1 == sc2
return bytes.Equal(sc1.Bytes(nil), sc2.Bytes(nil))
}
if err := quick.Check(f2, nil); err != nil {
t.Errorf("failed scalar->bytes->scalar round-trip: %v", err)
@ -48,7 +46,7 @@ func TestFromUniformBytes(t *testing.T) {
mod.Add(mod, new(big.Int).Lsh(big.NewInt(1), 252))
f := func(in [64]byte, sc Scalar) bool {
sc.FromUniformBytes(in[:])
if !scMinimal(sc[:]) {
if !scMinimal(sc.Bytes(nil)) {
return false
}
b := sc.Bytes(nil)
@ -83,7 +81,7 @@ func TestMulDistributesOverAdd(t *testing.T) {
t3.Mul(&y, &z)
t2.Add(&t2, &t3)
return t1.Equal(&t2) == 1 && scMinimal(t1[:]) && scMinimal(t2[:])
return t1.Equal(&t2) == 1 && scMinimal(t1.Bytes(nil)) && scMinimal(t2.Bytes(nil))
}
if err := quick.Check(mulDistributesOverAdd, quickCheckConfig); err != nil {
@ -92,12 +90,18 @@ func TestMulDistributesOverAdd(t *testing.T) {
}
func TestNonAdjacentForm(t *testing.T) {
s := Scalar([32]byte{
in := [32]byte{
0x1a, 0x0e, 0x97, 0x8a, 0x90, 0xf6, 0x62, 0x2d,
0x37, 0x47, 0x02, 0x3f, 0x8a, 0xd8, 0x26, 0x4d,
0xa7, 0x58, 0xaa, 0x1b, 0x88, 0xe0, 0x40, 0xd1,
0x58, 0x9e, 0x7b, 0x7f, 0x23, 0x76, 0xef, 0x09,
})
}
s := new(Scalar)
err := s.FromCanonicalBytes(in[:])
if err != nil {
t.Fatal(err)
}
expectedNaf := [256]int8{
0, 13, 0, 0, 0, 0, 0, 0, 0, 7, 0, 0, 0, 0, 0, 0, -9, 0, 0, 0, 0, -11, 0, 0, 0, 0, 3, 0, 0, 0, 0, 1,
0, 0, 0, 0, 9, 0, 0, 0, 0, -5, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 11, 0, 0, 0, 0, 11, 0, 0, 0, 0, 0,