ristretto255: replace backend with filippo.io/edwards25519

Authored by Filippo Valsorda on 2021-06-05 12:44:30 +02:00; committed by Filippo Valsorda
parent ee040a3529
commit 6bb3979858
25 changed files with 255 additions and 4353 deletions

fe.go

@@ -1,111 +0,0 @@
// Copyright 2019 The Go Authors. All rights reserved.
// Copyright 2019 George Tankersley. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ristretto255
import (
"math/big"
"github.com/gtank/ristretto255/internal/radix51"
)
// fePow22523 sets out to z^((p-5)/8). (p-5)/8 is 2^252-3.
func fePow22523(out, z *radix51.FieldElement) *radix51.FieldElement {
// Refactored from golang.org/x/crypto/ed25519/internal/edwards25519.
var t0, t1, t2 radix51.FieldElement
t0.Square(z)
for i := 1; i < 1; i++ {
t0.Square(&t0)
}
t1.Square(&t0)
for i := 1; i < 2; i++ {
t1.Square(&t1)
}
t1.Mul(z, &t1)
t0.Mul(&t0, &t1)
t0.Square(&t0)
for i := 1; i < 1; i++ {
t0.Square(&t0)
}
t0.Mul(&t1, &t0)
t1.Square(&t0)
for i := 1; i < 5; i++ {
t1.Square(&t1)
}
t0.Mul(&t1, &t0)
t1.Square(&t0)
for i := 1; i < 10; i++ {
t1.Square(&t1)
}
t1.Mul(&t1, &t0)
t2.Square(&t1)
for i := 1; i < 20; i++ {
t2.Square(&t2)
}
t1.Mul(&t2, &t1)
t1.Square(&t1)
for i := 1; i < 10; i++ {
t1.Square(&t1)
}
t0.Mul(&t1, &t0)
t1.Square(&t0)
for i := 1; i < 50; i++ {
t1.Square(&t1)
}
t1.Mul(&t1, &t0)
t2.Square(&t1)
for i := 1; i < 100; i++ {
t2.Square(&t2)
}
t1.Mul(&t2, &t1)
t1.Square(&t1)
for i := 1; i < 50; i++ {
t1.Square(&t1)
}
t0.Mul(&t1, &t0)
t0.Square(&t0)
for i := 1; i < 2; i++ {
t0.Square(&t0)
}
return out.Mul(&t0, z)
}
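For reference, the exponent computed by the chain above works out as expected for p = 2^255 - 19:

(p - 5)/8 = (2^255 - 24)/8 = 2^252 - 3

so fePow22523 returns z^(2^252 - 3) via a fixed addition chain of squarings and multiplications.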
// feSqrtRatio sets r to the square root of the ratio of u and v, according to
// Section 3.1.3 of draft-hdevalence-cfrg-ristretto-00.
func feSqrtRatio(r, u, v *radix51.FieldElement) (wasSquare int) {
var a, b radix51.FieldElement
v3 := a.Mul(a.Square(v), v) // v^3 = v^2 * v
v7 := b.Mul(b.Square(v3), v) // v^7 = (v^3)^2 * v
// r = (u * v3) * (u * v7)^((p-5)/8)
uv3 := a.Mul(u, v3) // (u * v3)
uv7 := b.Mul(u, v7) // (u * v7)
r.Mul(uv3, fePow22523(r, uv7))
check := a.Mul(v, a.Square(r)) // check = v * r^2
uNeg := b.Neg(u)
correctSignSqrt := check.Equal(u)
flippedSignSqrt := check.Equal(uNeg)
flippedSignSqrtI := check.Equal(uNeg.Mul(uNeg, sqrtM1))
rPrime := b.Mul(r, sqrtM1) // r_prime = SQRT_M1 * r
// r = CT_SELECT(r_prime IF flipped_sign_sqrt | flipped_sign_sqrt_i ELSE r)
r.Select(rPrime, r, flippedSignSqrt|flippedSignSqrtI)
r.Abs(r) // Choose the nonnegative square root.
return correctSignSqrt | flippedSignSqrt
}
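Since p ≡ 5 (mod 8), the candidate computed above is the usual (u/v)^((p+3)/8) square-root attempt, evaluated without a field inversion:

r = u·v^3·(u·v^7)^((p-5)/8) ≡ u^((p+3)/8) · v^((7p-11)/8) ≡ (u/v)^((p+3)/8)  (mod p)

If u/v is a square, r is (up to sign) its square root and wasSquare is 1; otherwise v·r^2 differs from u by a factor of -1 or ±sqrt(-1), and the constant-time comparisons and the SQRT_M1 correction above handle those cases as the draft specifies.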
func fieldElementFromDecimal(s string) *radix51.FieldElement {
n, ok := new(big.Int).SetString(s, 10)
if !ok {
panic("ristretto255: not a valid decimal: " + s)
}
return new(radix51.FieldElement).FromBig(n)
}

go.mod

@@ -1,3 +1,5 @@
module github.com/gtank/ristretto255
go 1.11
go 1.17
require filippo.io/edwards25519 v1.0.0-rc.1

go.sum (new file)

@@ -0,0 +1,2 @@
filippo.io/edwards25519 v1.0.0-rc.1 h1:m0VOOB23frXZvAOK44usCgLWvtsxIoMCTBGJZlpmGfU=
filippo.io/edwards25519 v1.0.0-rc.1/go.mod h1:N1IkdkCkiLB6tki+MYJoSx2JTY9NUlxZE7eHn5EwJns=
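For orientation, a minimal sketch (not part of this commit) of the filippo.io/edwards25519 v1.0.0-rc.1 API that now backs the group operations; the exact wiring inside ristretto255 may differ:

package main

import (
	"fmt"

	"filippo.io/edwards25519"
)

func main() {
	// A scalar for SetCanonicalBytes must be 32 little-endian bytes,
	// already reduced modulo the group order.
	var buf [32]byte
	buf[0] = 42
	s, err := edwards25519.NewScalar().SetCanonicalBytes(buf[:])
	if err != nil {
		panic(err)
	}

	// Fixed-base and variable-base scalar multiplication, both constant
	// time, replacing the internal BasepointMul/ScalarMul deleted below.
	B := edwards25519.NewGeneratorPoint()
	p := edwards25519.NewIdentityPoint().ScalarBaseMult(s)
	q := edwards25519.NewIdentityPoint().ScalarMult(s, B)

	fmt.Println(p.Equal(q) == 1) // prints true
}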

@@ -1,328 +0,0 @@
// Copyright (c) 2017 George Tankersley. All rights reserved.
// Copyright (c) 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package edwards25519 implements group logic for the twisted Edwards curve
//
// -x^2 + y^2 = 1 + -(121665/121666)*x^2*y^2
//
// This is better known as the Edwards curve equivalent to curve25519, and is
// the curve used by the Ed25519 signature scheme.
package edwards25519
import (
"github.com/gtank/ristretto255/internal/radix51"
)
// D is a constant in the curve equation.
var D = &radix51.FieldElement{929955233495203, 466365720129213,
1662059464998953, 2033849074728123, 1442794654840575}
var d2 = new(radix51.FieldElement).Add(D, D)
// Point types.
// TODO: write documentation
// TODO: rename (T,X,Y,Z) to (W0,W1,W2,W3) for P2 and P3 models?
// https://doc-internal.dalek.rs/curve25519_dalek/backend/serial/curve_models/index.html
type ProjP1xP1 struct {
X, Y, Z, T radix51.FieldElement
}
type ProjP2 struct {
X, Y, Z radix51.FieldElement
}
type ProjP3 struct {
X, Y, Z, T radix51.FieldElement
}
type ProjCached struct {
YplusX, YminusX, Z, T2d radix51.FieldElement
}
type AffineCached struct {
YplusX, YminusX, T2d radix51.FieldElement
}
// B is the Ed25519 basepoint.
var B = ProjP3{
X: radix51.FieldElement([5]uint64{1738742601995546, 1146398526822698, 2070867633025821, 562264141797630, 587772402128613}),
Y: radix51.FieldElement([5]uint64{1801439850948184, 1351079888211148, 450359962737049, 900719925474099, 1801439850948198}),
Z: radix51.FieldElement([5]uint64{1, 0, 0, 0, 0}),
T: radix51.FieldElement([5]uint64{1841354044333475, 16398895984059, 755974180946558, 900171276175154, 1821297809914039}),
}
// Constructors.
func (v *ProjP1xP1) Zero() *ProjP1xP1 {
v.X.Zero()
v.Y.One()
v.Z.One()
v.T.One()
return v
}
func (v *ProjP2) Zero() *ProjP2 {
v.X.Zero()
v.Y.One()
v.Z.One()
return v
}
func (v *ProjP3) Zero() *ProjP3 {
v.X.Zero()
v.Y.One()
v.Z.One()
v.T.Zero()
return v
}
func (v *ProjCached) Zero() *ProjCached {
v.YplusX.One()
v.YminusX.One()
v.Z.One()
v.T2d.Zero()
return v
}
func (v *AffineCached) Zero() *AffineCached {
v.YplusX.One()
v.YminusX.One()
v.T2d.Zero()
return v
}
// Assignments.
func (v *ProjP3) Set(u *ProjP3) *ProjP3 {
*v = *u
return v
}
// Conversions.
func (v *ProjP2) FromP1xP1(p *ProjP1xP1) *ProjP2 {
v.X.Mul(&p.X, &p.T)
v.Y.Mul(&p.Y, &p.Z)
v.Z.Mul(&p.Z, &p.T)
return v
}
func (v *ProjP2) FromP3(p *ProjP3) *ProjP2 {
v.X.Set(&p.X)
v.Y.Set(&p.Y)
v.Z.Set(&p.Z)
return v
}
func (v *ProjP3) FromP1xP1(p *ProjP1xP1) *ProjP3 {
v.X.Mul(&p.X, &p.T)
v.Y.Mul(&p.Y, &p.Z)
v.Z.Mul(&p.Z, &p.T)
v.T.Mul(&p.X, &p.Y)
return v
}
func (v *ProjP3) FromP2(p *ProjP2) *ProjP3 {
v.X.Mul(&p.X, &p.Z)
v.Y.Mul(&p.Y, &p.Z)
v.Z.Square(&p.Z)
v.T.Mul(&p.X, &p.Y)
return v
}
func (v *ProjCached) FromP3(p *ProjP3) *ProjCached {
v.YplusX.Add(&p.Y, &p.X)
v.YminusX.Sub(&p.Y, &p.X)
v.Z.Set(&p.Z)
v.T2d.Mul(&p.T, d2)
return v
}
func (v *AffineCached) FromP3(p *ProjP3) *AffineCached {
v.YplusX.Add(&p.Y, &p.X)
v.YminusX.Sub(&p.Y, &p.X)
v.T2d.Mul(&p.T, d2)
var invZ radix51.FieldElement
invZ.Invert(&p.Z)
v.YplusX.Mul(&v.YplusX, &invZ)
v.YminusX.Mul(&v.YminusX, &invZ)
v.T2d.Mul(&v.T2d, &invZ)
return v
}
// (Re)addition and subtraction.
func (v *ProjP3) Add(p, q *ProjP3) *ProjP3 {
result := ProjP1xP1{}
qCached := ProjCached{}
qCached.FromP3(q)
result.Add(p, &qCached)
v.FromP1xP1(&result)
return v
}
func (v *ProjP3) Sub(p, q *ProjP3) *ProjP3 {
result := ProjP1xP1{}
qCached := ProjCached{}
qCached.FromP3(q)
result.Sub(p, &qCached)
v.FromP1xP1(&result)
return v
}
func (v *ProjP1xP1) Add(p *ProjP3, q *ProjCached) *ProjP1xP1 {
var YplusX, YminusX, PP, MM, TT2d, ZZ2 radix51.FieldElement
YplusX.Add(&p.Y, &p.X)
YminusX.Sub(&p.Y, &p.X)
PP.Mul(&YplusX, &q.YplusX)
MM.Mul(&YminusX, &q.YminusX)
TT2d.Mul(&p.T, &q.T2d)
ZZ2.Mul(&p.Z, &q.Z)
ZZ2.Add(&ZZ2, &ZZ2)
v.X.Sub(&PP, &MM)
v.Y.Add(&PP, &MM)
v.Z.Add(&ZZ2, &TT2d)
v.T.Sub(&ZZ2, &TT2d)
return v
}
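Written out, Add follows the extended-coordinates ("HWCD") addition formulas for the a = -1 curve. With P = (X1 : Y1 : Z1 : T1) and the cached Q = (Y2+X2, Y2-X2, Z2, 2d·T2):

PP = (Y1+X1)(Y2+X2),  MM = (Y1-X1)(Y2-X2),  TT2d = 2d·T1·T2,  ZZ2 = 2·Z1·Z2
X3 = PP - MM,  Y3 = PP + MM,  Z3 = ZZ2 + TT2d,  T3 = ZZ2 - TT2d

and the subsequent FromP1xP1 call completes the sum as (X3·T3 : Y3·Z3 : Z3·T3 : X3·Y3).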
func (v *ProjP1xP1) Sub(p *ProjP3, q *ProjCached) *ProjP1xP1 {
var YplusX, YminusX, PP, MM, TT2d, ZZ2 radix51.FieldElement
YplusX.Add(&p.Y, &p.X)
YminusX.Sub(&p.Y, &p.X)
PP.Mul(&YplusX, &q.YminusX) // flipped sign
MM.Mul(&YminusX, &q.YplusX) // flipped sign
TT2d.Mul(&p.T, &q.T2d)
ZZ2.Mul(&p.Z, &q.Z)
ZZ2.Add(&ZZ2, &ZZ2)
v.X.Sub(&PP, &MM)
v.Y.Add(&PP, &MM)
v.Z.Sub(&ZZ2, &TT2d) // flipped sign
v.T.Add(&ZZ2, &TT2d) // flipped sign
return v
}
func (v *ProjP1xP1) AddAffine(p *ProjP3, q *AffineCached) *ProjP1xP1 {
var YplusX, YminusX, PP, MM, TT2d, Z2 radix51.FieldElement
YplusX.Add(&p.Y, &p.X)
YminusX.Sub(&p.Y, &p.X)
PP.Mul(&YplusX, &q.YplusX)
MM.Mul(&YminusX, &q.YminusX)
TT2d.Mul(&p.T, &q.T2d)
Z2.Add(&p.Z, &p.Z)
v.X.Sub(&PP, &MM)
v.Y.Add(&PP, &MM)
v.Z.Add(&Z2, &TT2d)
v.T.Sub(&Z2, &TT2d)
return v
}
func (v *ProjP1xP1) SubAffine(p *ProjP3, q *AffineCached) *ProjP1xP1 {
var YplusX, YminusX, PP, MM, TT2d, Z2 radix51.FieldElement
YplusX.Add(&p.Y, &p.X)
YminusX.Sub(&p.Y, &p.X)
PP.Mul(&YplusX, &q.YminusX) // flipped sign
MM.Mul(&YminusX, &q.YplusX) // flipped sign
TT2d.Mul(&p.T, &q.T2d)
Z2.Add(&p.Z, &p.Z)
v.X.Sub(&PP, &MM)
v.Y.Add(&PP, &MM)
v.Z.Sub(&Z2, &TT2d) // flipped sign
v.T.Add(&Z2, &TT2d) // flipped sign
return v
}
// Doubling.
func (v *ProjP1xP1) Double(p *ProjP2) *ProjP1xP1 {
var XX, YY, ZZ2, XplusYsq radix51.FieldElement
XX.Square(&p.X)
YY.Square(&p.Y)
ZZ2.Square(&p.Z)
ZZ2.Add(&ZZ2, &ZZ2)
XplusYsq.Add(&p.X, &p.Y)
XplusYsq.Square(&XplusYsq)
v.Y.Add(&YY, &XX)
v.Z.Sub(&YY, &XX)
v.X.Sub(&XplusYsq, &v.Y)
v.T.Sub(&ZZ2, &v.Z)
return v
}
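In the same notation, Double computes the standard a = -1 doubling:

XX = X1^2,  YY = Y1^2,  ZZ2 = 2·Z1^2
Y' = YY + XX,  Z' = YY - XX,  X' = (X1+Y1)^2 - Y' = 2·X1·Y1,  T' = ZZ2 - Z'

with FromP1xP1 again producing (X'·T' : Y'·Z' : Z'·T' : X'·Y').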
// Negation.
func (v *ProjP3) Neg(p *ProjP3) *ProjP3 {
v.X.Neg(&p.X)
v.Y.Set(&p.Y)
v.Z.Set(&p.Z)
v.T.Neg(&p.T)
return v
}
// by @ebfull
// https://github.com/dalek-cryptography/curve25519-dalek/pull/226/files
func (v *ProjP3) Equal(u *ProjP3) int {
var t1, t2, t3, t4 radix51.FieldElement
t1.Mul(&v.X, &u.Z)
t2.Mul(&u.X, &v.Z)
t3.Mul(&v.Y, &u.Z)
t4.Mul(&u.Y, &v.Z)
return t1.Equal(&t2) & t3.Equal(&t4)
}
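Equal avoids field inversions by cross-multiplying: two projective points are equal iff

X1·Z2 = X2·Z1  and  Y1·Z2 = Y2·Z1,

which is equivalent to X1/Z1 = X2/Z2 and Y1/Z1 = Y2/Z2 and keeps the comparison constant time.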
// Constant-time operations
// Select sets v to a if cond == 1 and to b if cond == 0.
func (v *ProjCached) Select(a, b *ProjCached, cond int) *ProjCached {
v.YplusX.Select(&a.YplusX, &b.YplusX, cond)
v.YminusX.Select(&a.YminusX, &b.YminusX, cond)
v.Z.Select(&a.Z, &b.Z, cond)
v.T2d.Select(&a.T2d, &b.T2d, cond)
return v
}
// Select sets v to a if cond == 1 and to b if cond == 0.
func (v *AffineCached) Select(a, b *AffineCached, cond int) *AffineCached {
v.YplusX.Select(&a.YplusX, &b.YplusX, cond)
v.YminusX.Select(&a.YminusX, &b.YminusX, cond)
v.T2d.Select(&a.T2d, &b.T2d, cond)
return v
}
// CondNeg negates v if cond == 1 and leaves it unchanged if cond == 0.
func (v *ProjCached) CondNeg(cond int) *ProjCached {
radix51.CondSwap(&v.YplusX, &v.YminusX, cond)
v.T2d.CondNeg(&v.T2d, cond)
return v
}
// CondNeg negates v if cond == 1 and leaves it unchanged if cond == 0.
func (v *AffineCached) CondNeg(cond int) *AffineCached {
radix51.CondSwap(&v.YplusX, &v.YminusX, cond)
v.T2d.CondNeg(&v.T2d, cond)
return v
}

@@ -1,47 +0,0 @@
// Copyright 2019 Henry de Valence. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package edwards25519
import (
"testing"
)
func TestAddSubNegOnBasePoint(t *testing.T) {
var Bneg ProjP3
var tmpP2 ProjP2
var tmpP1xP1 ProjP1xP1
var tmpCached ProjCached
Bneg.Neg(&B)
var checkLhs, checkRhs, zero ProjP3
zero.Zero()
tmpCached.FromP3(&B)
tmpP1xP1.Add(&B, &tmpCached)
checkLhs.FromP1xP1(&tmpP1xP1)
tmpP2.FromP3(&B)
tmpP1xP1.Double(&tmpP2)
checkRhs.FromP1xP1(&tmpP1xP1)
if checkLhs.Equal(&checkRhs) != 1 {
t.Error("B + B != [2]B")
}
tmpCached.FromP3(&B)
tmpP1xP1.Sub(&B, &tmpCached)
checkLhs.FromP1xP1(&tmpP1xP1)
tmpCached.FromP3(&Bneg)
tmpP1xP1.Add(&B, &tmpCached)
checkRhs.FromP1xP1(&tmpP1xP1)
if checkLhs.Equal(&checkRhs) != 1 {
t.Error("B - B != B + (-B)")
}
if zero.Equal(&checkLhs) != 1 {
t.Error("B - B != 0")
}
if zero.Equal(&checkRhs) != 1 {
t.Error("B + (-B) != 0")
}
}

@@ -1,288 +0,0 @@
// Copyright (c) 2019 Henry de Valence.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package edwards25519
import (
"github.com/gtank/ristretto255/internal/scalar"
)
// Set v to x*B, where B is the Ed25519 basepoint, and return v.
//
// The scalar multiplication is done in constant time.
func (v *ProjP3) BasepointMul(x *scalar.Scalar) *ProjP3 {
// Write x = sum(x_i * 16^i) so x*B = sum( B*x_i*16^i )
// as described in the Ed25519 paper
//
// Group even and odd coefficients
// x*B = x_0*16^0*B + x_2*16^2*B + ... + x_62*16^62*B
// + x_1*16^1*B + x_3*16^3*B + ... + x_63*16^63*B
// x*B = x_0*16^0*B + x_2*16^2*B + ... + x_62*16^62*B
// + 16*( x_1*16^0*B + x_3*16^2*B + ... + x_63*16^62*B)
//
// We use a lookup table for each i to get x_i*16^(2*i)*B
// and do four doublings to multiply by 16.
digits := x.SignedRadix16()
multiple := &AffineCached{}
tmp1 := &ProjP1xP1{}
tmp2 := &ProjP2{}
// Accumulate the odd components first
v.Zero()
for i := 1; i < 64; i += 2 {
basepointTable[i/2].SelectInto(multiple, digits[i])
tmp1.AddAffine(v, multiple)
v.FromP1xP1(tmp1)
}
// Multiply by 16
tmp2.FromP3(v) // tmp2 = v in P2 coords
tmp1.Double(tmp2) // tmp1 = 2*v in P1xP1 coords
tmp2.FromP1xP1(tmp1) // tmp2 = 2*v in P2 coords
tmp1.Double(tmp2) // tmp1 = 4*v in P1xP1 coords
tmp2.FromP1xP1(tmp1) // tmp2 = 4*v in P2 coords
tmp1.Double(tmp2) // tmp1 = 8*v in P1xP1 coords
tmp2.FromP1xP1(tmp1) // tmp2 = 8*v in P2 coords
tmp1.Double(tmp2) // tmp1 = 16*v in P1xP1 coords
v.FromP1xP1(tmp1) // now v = 16*(odd components)
// Accumulate the even components
for i := 0; i < 64; i += 2 {
basepointTable[i/2].SelectInto(multiple, digits[i])
tmp1.AddAffine(v, multiple)
v.FromP1xP1(tmp1)
}
return v
}
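As a concrete (editorial) example of the signed radix-16 digits used above: every scalar is written as x = sum_{i=0}^{63} x_i·16^i with small signed digits |x_i| <= 8, so e.g.

250 = -6 + 0·16 + 1·16^2   =>   digits (-6, 0, 1, 0, ..., 0)

and each digit selects a precomputed multiple (or its negation) from the corresponding lookup table.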
// Set v to x*Q, and return v. v and q may alias.
//
// The scalar multiplication is done in constant time.
func (v *ProjP3) ScalarMul(x *scalar.Scalar, q *ProjP3) *ProjP3 {
var table projLookupTable
table.FromP3(q)
// v and q could alias, but once the table is built we can clobber v.
v.Zero()
// Write x = sum(x_i * 16^i)
// so x*Q = sum( Q*x_i*16^i )
// = Q*x_0 + 16*(Q*x_1 + 16*( ... + Q*x_63) ... )
// <------compute inside out---------
//
// We use the lookup table to get the x_i*Q values
// and do four doublings to compute 16*Q
digits := x.SignedRadix16()
// Unwrap first loop iteration to save computing 16*identity
multiple := &ProjCached{}
tmp1 := &ProjP1xP1{}
tmp2 := &ProjP2{}
table.SelectInto(multiple, digits[63])
tmp1.Add(v, multiple) // tmp1 = x_63*Q in P1xP1 coords
for i := 62; i >= 0; i-- {
tmp2.FromP1xP1(tmp1) // tmp2 = (prev) in P2 coords
tmp1.Double(tmp2) // tmp1 = 2*(prev) in P1xP1 coords
tmp2.FromP1xP1(tmp1) // tmp2 = 2*(prev) in P2 coords
tmp1.Double(tmp2) // tmp1 = 4*(prev) in P1xP1 coords
tmp2.FromP1xP1(tmp1) // tmp2 = 4*(prev) in P2 coords
tmp1.Double(tmp2) // tmp1 = 8*(prev) in P1xP1 coords
tmp2.FromP1xP1(tmp1) // tmp2 = 8*(prev) in P2 coords
tmp1.Double(tmp2) // tmp1 = 16*(prev) in P1xP1 coords
v.FromP1xP1(tmp1) // v = 16*(prev) in P3 coords
table.SelectInto(multiple, digits[i])
tmp1.Add(v, multiple) // tmp1 = x_i*Q + 16*(prev) in P1xP1 coords
}
v.FromP1xP1(tmp1)
return v
}
// Set v to the result of a multiscalar multiplication and return v.
//
// The multiscalar multiplication is sum(scalars[i]*points[i]).
//
// The multiscalar multiplication is performed in constant time.
func (v *ProjP3) MultiscalarMul(scalars []scalar.Scalar, points []*ProjP3) *ProjP3 {
if len(scalars) != len(points) {
panic("called MultiscalarMul with different size inputs")
}
// Proceed as in the single-base case, but share doublings
// between each point in the multiscalar equation.
// Build lookup tables for each point
tables := make([]projLookupTable, len(points))
for i := range tables {
tables[i].FromP3(points[i])
}
// Compute signed radix-16 digits for each scalar
digits := make([][64]int8, len(scalars))
for i := range digits {
digits[i] = scalars[i].SignedRadix16()
}
// Unwrap first loop iteration to save computing 16*identity
multiple := &ProjCached{}
tmp1 := &ProjP1xP1{}
tmp2 := &ProjP2{}
// Lookup-and-add the appropriate multiple of each input point
for j := range tables {
tables[j].SelectInto(multiple, digits[j][63])
tmp1.Add(v, multiple) // tmp1 = v + x_(j,63)*Q in P1xP1 coords
v.FromP1xP1(tmp1) // update v
}
tmp2.FromP3(v) // set up tmp2 = v in P2 coords for next iteration
for i := 62; i >= 0; i-- {
tmp1.Double(tmp2) // tmp1 = 2*(prev) in P1xP1 coords
tmp2.FromP1xP1(tmp1) // tmp2 = 2*(prev) in P2 coords
tmp1.Double(tmp2) // tmp1 = 4*(prev) in P1xP1 coords
tmp2.FromP1xP1(tmp1) // tmp2 = 4*(prev) in P2 coords
tmp1.Double(tmp2) // tmp1 = 8*(prev) in P1xP1 coords
tmp2.FromP1xP1(tmp1) // tmp2 = 8*(prev) in P2 coords
tmp1.Double(tmp2) // tmp1 = 16*(prev) in P1xP1 coords
v.FromP1xP1(tmp1) // v = 16*(prev) in P3 coords
// Lookup-and-add the appropriate multiple of each input point
for j := range tables {
tables[j].SelectInto(multiple, digits[j][i])
tmp1.Add(v, multiple) // tmp1 = v + x_(j,i)*Q in P1xP1 coords
v.FromP1xP1(tmp1) // update v
}
tmp2.FromP3(v) // set up tmp2 = v in P2 coords for next iteration
}
return v
}
// Set v to a*A + b*B, where B is the Ed25519 basepoint, and return v.
//
// The scalar multiplication is done in variable time.
func (v *ProjP3) VartimeDoubleBaseMul(a *scalar.Scalar, A *ProjP3, b *scalar.Scalar) *ProjP3 {
// Similarly to the single variable-base approach, we compute
// digits and use them with a lookup table. However, because
// we are allowed to do variable-time operations, we don't
// need constant-time lookups or constant-time digit
// computations.
//
// So we use a non-adjacent form of some width w instead of
// radix 16. This is like a binary representation (one digit
// for each binary place) but we allow the digits to grow in
// magnitude up to 2^{w-1} so that the nonzero digits are as
// sparse as possible. Intuitively, this "condenses" the
// "mass" of the scalar onto sparse coefficients (meaning
// fewer additions).
var aTable nafLookupTable5
aTable.FromP3(A)
// Because the basepoint is fixed, we can use a wider NAF
// corresponding to a bigger table.
aNaf := a.NonAdjacentForm(5)
bNaf := b.NonAdjacentForm(8)
// Find the first nonzero coefficient.
i := 255
for j := i; j >= 0; j-- {
if aNaf[j] != 0 || bNaf[j] != 0 {
break
}
}
multA := &ProjCached{}
multB := &AffineCached{}
tmp1 := &ProjP1xP1{}
tmp2 := &ProjP2{}
tmp2.Zero()
v.Zero()
// Move from high to low bits, doubling the accumulator
// at each iteration and checking whether there is a nonzero
// coefficient to look up a multiple of.
for ; i >= 0; i-- {
tmp1.Double(tmp2)
// Only update v if we have a nonzero coeff to add in.
if aNaf[i] > 0 {
v.FromP1xP1(tmp1)
aTable.SelectInto(multA, aNaf[i])
tmp1.Add(v, multA)
} else if aNaf[i] < 0 {
v.FromP1xP1(tmp1)
aTable.SelectInto(multA, -aNaf[i])
tmp1.Sub(v, multA)
}
if bNaf[i] > 0 {
v.FromP1xP1(tmp1)
basepointNafTable.SelectInto(multB, bNaf[i])
tmp1.AddAffine(v, multB)
} else if bNaf[i] < 0 {
v.FromP1xP1(tmp1)
basepointNafTable.SelectInto(multB, -bNaf[i])
tmp1.SubAffine(v, multB)
}
tmp2.FromP1xP1(tmp1)
}
v.FromP2(tmp2)
return v
}
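A small (editorial) NAF example: in a width-5 non-adjacent form every nonzero digit is odd with |d| < 16, and any five consecutive digits contain at most one nonzero entry, so e.g.

31 = 2^5 - 1   =>   width-5 NAF digits (-1, 0, 0, 0, 0, 1)

two nonzero digits instead of five, which is what keeps the variable-time loop above cheap.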
// Set v to the result of a multiscalar multiplication and return v.
//
// The multiscalar multiplication is sum(scalars[i]*points[i]).
//
// The multiscalar multiplication is performed in variable time.
func (v *ProjP3) VartimeMultiscalarMul(scalars []scalar.Scalar, points []*ProjP3) *ProjP3 {
if len(scalars) != len(points) {
panic("called MultiscalarMul with different size inputs")
}
// Generalize double-base NAF computation to arbitrary sizes.
// Here all the points are dynamic, so we only use the smaller
// tables.
// Build lookup tables for each point
tables := make([]nafLookupTable5, len(points))
for i := range tables {
tables[i].FromP3(points[i])
}
// Compute a NAF for each scalar
nafs := make([][256]int8, len(scalars))
for i := range nafs {
nafs[i] = scalars[i].NonAdjacentForm(5)
}
multiple := &ProjCached{}
tmp1 := &ProjP1xP1{}
tmp2 := &ProjP2{}
tmp2.Zero()
v.Zero()
// Move from high to low bits, doubling the accumulator
// at each iteration and checking whether there is a nonzero
// coefficient to look up a multiple of.
//
// Skip trying to find the first nonzero coefficient, because
// searching might be more work than a few extra doublings.
for i := 255; i >= 0; i-- {
tmp1.Double(tmp2)
for j := range nafs {
if nafs[j][i] > 0 {
v.FromP1xP1(tmp1)
tables[j].SelectInto(multiple, nafs[j][i])
tmp1.Add(v, multiple)
} else if nafs[j][i] < 0 {
v.FromP1xP1(tmp1)
tables[j].SelectInto(multiple, -nafs[j][i])
tmp1.Sub(v, multiple)
}
}
tmp2.FromP1xP1(tmp1)
}
v.FromP2(tmp2)
return v
}

@@ -1,258 +0,0 @@
// Copyright 2019 Henry de Valence. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package edwards25519
import (
"testing"
"testing/quick"
"github.com/gtank/ristretto255/internal/radix51"
"github.com/gtank/ristretto255/internal/scalar"
)
// quickCheckConfig will make each quickcheck test run (2^6 * -quickchecks)
// times. The default value of -quickchecks is 100.
var (
quickCheckConfig = &quick.Config{MaxCountScale: 1 << 6}
// a random scalar generated using dalek.
dalekScalar = scalar.Scalar([32]byte{219, 106, 114, 9, 174, 249, 155, 89, 69, 203, 201, 93, 92, 116, 234, 187, 78, 115, 103, 172, 182, 98, 62, 103, 187, 136, 13, 100, 248, 110, 12, 4})
// the above, times the Ed25519 basepoint.
dalekScalarBasepoint = ProjP3{
X: radix51.FieldElement([5]uint64{778774234987948, 1589187156384239, 1213330452914652, 186161118421127, 2186284806803213}),
Y: radix51.FieldElement([5]uint64{1241255309069369, 1115278942994853, 1016511918109334, 1303231926552315, 1801448517689873}),
Z: radix51.FieldElement([5]uint64{353337085654440, 1327844406437681, 2207296012811921, 707394926933424, 917408459573183}),
T: radix51.FieldElement([5]uint64{585487439439725, 1792815221887900, 946062846079052, 1954901232609667, 1418300670001780}),
}
)
func TestScalarMulSmallScalars(t *testing.T) {
var z scalar.Scalar
var p, check ProjP3
p.ScalarMul(&z, &B)
check.Zero()
if check.Equal(&p) != 1 {
t.Error("0*B != 0")
}
z = scalar.Scalar([32]byte{1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0})
p.ScalarMul(&z, &B)
check.Set(&B)
if check.Equal(&p) != 1 {
t.Error("1*B != 1")
}
}
func TestScalarMulVsDalek(t *testing.T) {
var p ProjP3
p.ScalarMul(&dalekScalar, &B)
if dalekScalarBasepoint.Equal(&p) != 1 {
t.Error("Scalar mul does not match dalek")
}
}
func TestBasepointMulVsDalek(t *testing.T) {
var p ProjP3
p.BasepointMul(&dalekScalar)
if dalekScalarBasepoint.Equal(&p) != 1 {
t.Error("Scalar mul does not match dalek")
}
}
func TestVartimeDoubleBaseMulVsDalek(t *testing.T) {
var p ProjP3
var z scalar.Scalar
p.VartimeDoubleBaseMul(&dalekScalar, &B, &z)
if dalekScalarBasepoint.Equal(&p) != 1 {
t.Error("VartimeDoubleBaseMul fails with b=0")
}
p.VartimeDoubleBaseMul(&z, &B, &dalekScalar)
if dalekScalarBasepoint.Equal(&p) != 1 {
t.Error("VartimeDoubleBaseMul fails with a=0")
}
}
func TestScalarMulDistributesOverAdd(t *testing.T) {
scalarMulDistributesOverAdd := func(x, y scalar.Scalar) bool {
// The quickcheck generation strategy chooses a random
// 32-byte array, but we require that the high bit is
// unset. FIXME: make Scalar opaque. Until then,
// mask the high bits:
x[31] &= 127
y[31] &= 127
var z scalar.Scalar
z.Add(&x, &y)
var p, q, r, check ProjP3
p.ScalarMul(&x, &B)
q.ScalarMul(&y, &B)
r.ScalarMul(&z, &B)
check.Add(&p, &q)
return check.Equal(&r) == 1
}
if err := quick.Check(scalarMulDistributesOverAdd, quickCheckConfig); err != nil {
t.Error(err)
}
}
func TestBasepointTableGeneration(t *testing.T) {
// The basepoint table is 32 affineLookupTables,
// corresponding to (16^2i)*B for table i.
tmp1 := &ProjP1xP1{}
tmp2 := &ProjP2{}
tmp3 := &ProjP3{}
tmp3.Set(&B)
table := make([]affineLookupTable, 32)
for i := 0; i < 32; i++ {
// Build the table
table[i].FromP3(tmp3)
// Assert equality with the hardcoded one
if table[i] != basepointTable[i] {
t.Errorf("Basepoint table %d does not match", i)
}
// Set p = (16^2)*p = 256*p = 2^8*p
tmp2.FromP3(tmp3)
for j := 0; j < 7; j++ {
tmp1.Double(tmp2)
tmp2.FromP1xP1(tmp1)
}
tmp1.Double(tmp2)
tmp3.FromP1xP1(tmp1)
}
}
func TestScalarMulMatchesBasepointMul(t *testing.T) {
scalarMulMatchesBasepointMul := func(x scalar.Scalar) bool {
// FIXME opaque scalars
x[31] &= 127
var p, q ProjP3
p.ScalarMul(&x, &B)
q.BasepointMul(&x)
return p.Equal(&q) == 1
}
if err := quick.Check(scalarMulMatchesBasepointMul, quickCheckConfig); err != nil {
t.Error(err)
}
}
func TestMultiScalarMulMatchesBasepointMul(t *testing.T) {
multiScalarMulMatchesBasepointMul := func(x, y, z scalar.Scalar) bool {
// FIXME opaque scalars
x[31] &= 127
y[31] &= 127
z[31] &= 127
var p, q1, q2, q3, check ProjP3
p.MultiscalarMul([]scalar.Scalar{x, y, z}, []*ProjP3{&B, &B, &B})
q1.BasepointMul(&x)
q2.BasepointMul(&y)
q3.BasepointMul(&z)
check.Zero()
check.Add(&q1, &q2).Add(&check, &q3)
return p.Equal(&check) == 1
}
if err := quick.Check(multiScalarMulMatchesBasepointMul, quickCheckConfig); err != nil {
t.Error(err)
}
}
func TestBasepointNafTableGeneration(t *testing.T) {
var table nafLookupTable8
table.FromP3(&B)
if table != basepointNafTable {
t.Error("BasepointNafTable does not match")
}
}
func TestVartimeDoubleBaseMulMatchesBasepointMul(t *testing.T) {
vartimeDoubleBaseMulMatchesBasepointMul := func(x, y scalar.Scalar) bool {
// FIXME opaque scalars
x[31] &= 127
y[31] &= 127
var p, q1, q2, check ProjP3
p.VartimeDoubleBaseMul(&x, &B, &y)
q1.BasepointMul(&x)
q2.BasepointMul(&y)
check.Zero()
check.Add(&q1, &q2)
return p.Equal(&check) == 1
}
if err := quick.Check(vartimeDoubleBaseMulMatchesBasepointMul, quickCheckConfig); err != nil {
t.Error(err)
}
}
func TestVartimeMultiScalarMulMatchesBasepointMul(t *testing.T) {
vartimeMultiScalarMulMatchesBasepointMul := func(x, y, z scalar.Scalar) bool {
// FIXME opaque scalars
x[31] &= 127
y[31] &= 127
z[31] &= 127
var p, q1, q2, q3, check ProjP3
p.VartimeMultiscalarMul([]scalar.Scalar{x, y, z}, []*ProjP3{&B, &B, &B})
q1.BasepointMul(&x)
q2.BasepointMul(&y)
q3.BasepointMul(&z)
check.Zero()
check.Add(&q1, &q2).Add(&check, &q3)
return p.Equal(&check) == 1
}
if err := quick.Check(vartimeMultiScalarMulMatchesBasepointMul, quickCheckConfig); err != nil {
t.Error(err)
}
}
// Benchmarks.
func BenchmarkBasepointMul(t *testing.B) {
var p ProjP3
for i := 0; i < t.N; i++ {
p.BasepointMul(&dalekScalar)
}
}
func BenchmarkScalarMul(t *testing.B) {
var p ProjP3
for i := 0; i < t.N; i++ {
p.ScalarMul(&dalekScalar, &B)
}
}
func BenchmarkVartimeDoubleBaseMul(t *testing.B) {
var p ProjP3
for i := 0; i < t.N; i++ {
p.VartimeDoubleBaseMul(&dalekScalar, &B, &dalekScalar)
}
}
func BenchmarkMultiscalarMulSize8(t *testing.B) {
var p ProjP3
x := dalekScalar
for i := 0; i < t.N; i++ {
p.MultiscalarMul([]scalar.Scalar{x, x, x, x, x, x, x, x}, []*ProjP3{&B, &B, &B, &B, &B, &B, &B, &B})
}
}
// TODO: add BenchmarkVartimeMultiscalarMulSize8 (need to have
// different scalars & points to measure cache effects).

File diff suppressed because one or more lines are too long

@@ -1,129 +0,0 @@
// Copyright (c) 2019 Henry de Valence.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package edwards25519
import (
"crypto/subtle"
)
// A dynamic lookup table for variable-base, constant-time scalar muls.
type projLookupTable struct {
points [8]ProjCached
}
// A precomputed lookup table for fixed-base, constant-time scalar muls.
type affineLookupTable struct {
points [8]AffineCached
}
// A dynamic lookup table for variable-base, variable-time scalar muls.
type nafLookupTable5 struct {
points [8]ProjCached
}
// A precomputed lookup table for fixed-base, variable-time scalar muls.
type nafLookupTable8 struct {
points [64]AffineCached
}
// Constructors.
// Builds a lookup table at runtime. Fast.
func (v *projLookupTable) FromP3(q *ProjP3) {
// Goal: v.points[i] = (i+1)*Q, i.e., Q, 2Q, ..., 8Q
// This allows lookup of -8Q, ..., -Q, 0, Q, ..., 8Q
v.points[0].FromP3(q)
tmpP3 := ProjP3{}
tmpP1xP1 := ProjP1xP1{}
for i := 0; i < 7; i++ {
// Compute (i+1)*Q as Q + i*Q and convert to a ProjCached
// This is needlessly complicated because the API has explicit
// receivers instead of creating stack objects and relying on RVO
v.points[i+1].FromP3(tmpP3.FromP1xP1(tmpP1xP1.Add(q, &v.points[i])))
}
}
// This is not optimised for speed; affine tables should be precomputed.
func (v *affineLookupTable) FromP3(q *ProjP3) {
// Goal: v.points[i] = (i+1)*Q, i.e., Q, 2Q, ..., 8Q
// This allows lookup of -8Q, ..., -Q, 0, Q, ..., 8Q
v.points[0].FromP3(q)
tmpP3 := ProjP3{}
tmpP1xP1 := ProjP1xP1{}
for i := 0; i < 7; i++ {
// Compute (i+1)*Q as Q + i*Q and convert to AffineCached
v.points[i+1].FromP3(tmpP3.FromP1xP1(tmpP1xP1.AddAffine(q, &v.points[i])))
}
}
// Builds a lookup table at runtime. Fast.
func (v *nafLookupTable5) FromP3(q *ProjP3) {
// Goal: v.points[i] = (2*i+1)*Q, i.e., Q, 3Q, 5Q, ..., 15Q
// This allows lookup of -15Q, ..., -3Q, -Q, 0, Q, 3Q, ..., 15Q
v.points[0].FromP3(q)
q2 := ProjP3{}
q2.Add(q, q)
tmpP3 := ProjP3{}
tmpP1xP1 := ProjP1xP1{}
for i := 0; i < 7; i++ {
v.points[i+1].FromP3(tmpP3.FromP1xP1(tmpP1xP1.Add(&q2, &v.points[i])))
}
}
// This is not optimised for speed; affine tables should be precomputed.
func (v *nafLookupTable8) FromP3(q *ProjP3) {
v.points[0].FromP3(q)
q2 := ProjP3{}
q2.Add(q, q)
tmpP3 := ProjP3{}
tmpP1xP1 := ProjP1xP1{}
for i := 0; i < 63; i++ {
v.points[i+1].FromP3(tmpP3.FromP1xP1(tmpP1xP1.AddAffine(&q2, &v.points[i])))
}
}
// Selectors.
// Set dest to x*Q, where -8 <= x <= 8, in constant time.
func (v *projLookupTable) SelectInto(dest *ProjCached, x int8) {
// Compute xabs = |x|
xmask := x >> 7
xabs := uint8((x + xmask) ^ xmask)
dest.Zero()
for j := 1; j <= 8; j++ {
// Set dest = j*Q if |x| = j
cond := subtle.ConstantTimeByteEq(xabs, uint8(j))
dest.Select(&v.points[j-1], dest, cond)
}
// Now dest = |x|*Q, conditionally negate to get x*Q
dest.CondNeg(int(xmask & 1))
}
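To see the branchless absolute value at work (editorial example), take x = -3 as an int8:

xmask = x >> 7              = -1 (all ones)
xabs  = (x + xmask) ^ xmask = (-4) ^ (-1) = 3

so the loop selects 3*Q, and CondNeg(xmask & 1 = 1) then negates it to -3*Q; for x >= 0, xmask is 0 and nothing changes.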
// Set dest to x*Q, where -8 <= x <= 8, in constant time.
func (v *affineLookupTable) SelectInto(dest *AffineCached, x int8) {
// Compute xabs = |x|
xmask := x >> 7
xabs := uint8((x + xmask) ^ xmask)
dest.Zero()
for j := 1; j <= 8; j++ {
// Set dest = j*Q if |x| = j
cond := subtle.ConstantTimeByteEq(xabs, uint8(j))
dest.Select(&v.points[j-1], dest, cond)
}
// Now dest = |x|*Q, conditionally negate to get x*Q
dest.CondNeg(int(xmask & 1))
}
// Given odd x with 0 < x < 2^4, return x*Q (in variable time).
func (v *nafLookupTable5) SelectInto(dest *ProjCached, x int8) {
*dest = v.points[x/2]
}
// Given odd x with 0 < x < 2^7, return x*Q (in variable time).
func (v *nafLookupTable8) SelectInto(dest *AffineCached, x int8) {
*dest = v.points[x/2]
}

@@ -1,125 +0,0 @@
// Copyright 2019 Henry de Valence. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package edwards25519
import (
"testing"
)
func TestProjLookupTable(t *testing.T) {
var table projLookupTable
table.FromP3(&B)
var tmp1, tmp2, tmp3 ProjCached
table.SelectInto(&tmp1, 6)
table.SelectInto(&tmp2, -2)
table.SelectInto(&tmp3, -4)
// Expect T1 + T2 + T3 = identity
var accP1xP1 ProjP1xP1
var accP3, check ProjP3
accP3.Zero()
check.Zero()
accP1xP1.Add(&accP3, &tmp1)
accP3.FromP1xP1(&accP1xP1)
accP1xP1.Add(&accP3, &tmp2)
accP3.FromP1xP1(&accP1xP1)
accP1xP1.Add(&accP3, &tmp3)
accP3.FromP1xP1(&accP1xP1)
if accP3.Equal(&check) != 1 {
t.Errorf("Sanity check on ProjLookupTable.SelectInto failed! %x %x %x", tmp1, tmp2, tmp3)
}
}
func TestAffineLookupTable(t *testing.T) {
var table affineLookupTable
table.FromP3(&B)
var tmp1, tmp2, tmp3 AffineCached
table.SelectInto(&tmp1, 3)
table.SelectInto(&tmp2, -7)
table.SelectInto(&tmp3, 4)
// Expect T1 + T2 + T3 = identity
var accP1xP1 ProjP1xP1
var accP3, check ProjP3
accP3.Zero()
check.Zero()
accP1xP1.AddAffine(&accP3, &tmp1)
accP3.FromP1xP1(&accP1xP1)
accP1xP1.AddAffine(&accP3, &tmp2)
accP3.FromP1xP1(&accP1xP1)
accP1xP1.AddAffine(&accP3, &tmp3)
accP3.FromP1xP1(&accP1xP1)
if accP3.Equal(&check) != 1 {
t.Errorf("Sanity check on ProjLookupTable.SelectInto failed! %x %x %x", tmp1, tmp2, tmp3)
}
}
func TestNafLookupTable5(t *testing.T) {
var table nafLookupTable5
table.FromP3(&B)
var tmp1, tmp2, tmp3, tmp4 ProjCached
table.SelectInto(&tmp1, 9)
table.SelectInto(&tmp2, 11)
table.SelectInto(&tmp3, 7)
table.SelectInto(&tmp4, 13)
// Expect T1 + T2 = T3 + T4
var accP1xP1 ProjP1xP1
var lhs, rhs ProjP3
lhs.Zero()
rhs.Zero()
accP1xP1.Add(&lhs, &tmp1)
lhs.FromP1xP1(&accP1xP1)
accP1xP1.Add(&lhs, &tmp2)
lhs.FromP1xP1(&accP1xP1)
accP1xP1.Add(&rhs, &tmp3)
rhs.FromP1xP1(&accP1xP1)
accP1xP1.Add(&rhs, &tmp4)
rhs.FromP1xP1(&accP1xP1)
if lhs.Equal(&rhs) != 1 {
t.Errorf("Sanity check on nafLookupTable5 failed")
}
}
func TestNafLookupTable8(t *testing.T) {
var table nafLookupTable8
table.FromP3(&B)
var tmp1, tmp2, tmp3, tmp4 AffineCached
table.SelectInto(&tmp1, 49)
table.SelectInto(&tmp2, 11)
table.SelectInto(&tmp3, 35)
table.SelectInto(&tmp4, 25)
// Expect T1 + T2 = T3 + T4
var accP1xP1 ProjP1xP1
var lhs, rhs ProjP3
lhs.Zero()
rhs.Zero()
accP1xP1.AddAffine(&lhs, &tmp1)
lhs.FromP1xP1(&accP1xP1)
accP1xP1.AddAffine(&lhs, &tmp2)
lhs.FromP1xP1(&accP1xP1)
accP1xP1.AddAffine(&rhs, &tmp3)
rhs.FromP1xP1(&accP1xP1)
accP1xP1.AddAffine(&rhs, &tmp4)
rhs.FromP1xP1(&accP1xP1)
if lhs.Equal(&rhs) != 1 {
t.Errorf("Sanity check on nafLookupTable8 failed")
}
}

@@ -1,128 +0,0 @@
// Copyright (c) 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package radix51
import (
"testing"
"testing/quick"
)
func checkAliasingOneArg(f func(v, x *FieldElement) *FieldElement) func(v, x FieldElement) bool {
return func(v, x FieldElement) bool {
x1, v1 := x, x
// Calculate a reference f(x) without aliasing.
if out := f(&v, &x); out != &v && isInBounds(out) {
return false
}
// Test aliasing the argument and the receiver.
if out := f(&v1, &v1); out != &v1 || v1 != v {
return false
}
// Ensure the argument was not modified.
return x == x1
}
}
func checkAliasingTwoArgs(f func(v, x, y *FieldElement) *FieldElement) func(v, x, y FieldElement) bool {
return func(v, x, y FieldElement) bool {
x1, y1, v1 := x, y, FieldElement{}
// Calculate a reference f(x, y) without aliasing.
if out := f(&v, &x, &y); out != &v && isInBounds(out) {
return false
}
// Test aliasing the first argument and the receiver.
v1 = x
if out := f(&v1, &v1, &y); out != &v1 || v1 != v {
return false
}
// Test aliasing the second argument and the receiver.
v1 = y
if out := f(&v1, &x, &v1); out != &v1 || v1 != v {
return false
}
// Calculate a reference f(x, x) without aliasing.
if out := f(&v, &x, &x); out != &v {
return false
}
// Test aliasing the first argument and the receiver.
v1 = x
if out := f(&v1, &v1, &x); out != &v1 || v1 != v {
return false
}
// Test aliasing the second argument and the receiver.
v1 = x
if out := f(&v1, &x, &v1); out != &v1 || v1 != v {
return false
}
// Test aliasing both arguments and the receiver.
v1 = x
if out := f(&v1, &v1, &v1); out != &v1 || v1 != v {
return false
}
// Ensure the arguments were not modified.
return x == x1 && y == y1
}
}
func TestAliasing(t *testing.T) {
type target struct {
name string
oneArgF func(v, x *FieldElement) *FieldElement
twoArgsF func(v, x, y *FieldElement) *FieldElement
}
for _, tt := range []target{
{name: "Abs", oneArgF: (*FieldElement).Abs},
{name: "Invert", oneArgF: (*FieldElement).Invert},
{name: "Neg", oneArgF: (*FieldElement).Neg},
{name: "Set", oneArgF: (*FieldElement).Set},
{name: "Square", oneArgF: (*FieldElement).Square},
{
name: "CondNeg0",
oneArgF: func(v, x *FieldElement) *FieldElement {
return (*FieldElement).CondNeg(v, x, 0)
},
},
{
name: "CondNeg1",
oneArgF: func(v, x *FieldElement) *FieldElement {
return (*FieldElement).CondNeg(v, x, 1)
},
},
{name: "Mul", twoArgsF: (*FieldElement).Mul},
{name: "Add", twoArgsF: (*FieldElement).Add},
{name: "Sub", twoArgsF: (*FieldElement).Sub},
{
name: "Select0",
twoArgsF: func(v, x, y *FieldElement) *FieldElement {
return (*FieldElement).Select(v, x, y, 0)
},
},
{
name: "Select1",
twoArgsF: func(v, x, y *FieldElement) *FieldElement {
return (*FieldElement).Select(v, x, y, 1)
},
},
} {
var err error
switch {
case tt.oneArgF != nil:
err = quick.Check(checkAliasingOneArg(tt.oneArgF), &quick.Config{MaxCountScale: 1 << 8})
case tt.twoArgsF != nil:
err = quick.Check(checkAliasingTwoArgs(tt.twoArgsF), &quick.Config{MaxCountScale: 1 << 8})
}
if err != nil {
t.Errorf("%v: %v", tt.name, err)
}
}
}

@@ -1,40 +0,0 @@
// Copyright (c) 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package radix51_test
import (
"testing"
"github.com/gtank/ristretto255/internal/radix51"
)
func BenchmarkAdd(b *testing.B) {
var x, y radix51.FieldElement
x.One()
y.Add(radix51.One, radix51.One)
b.ResetTimer()
for i := 0; i < b.N; i++ {
x.Add(&x, &y)
}
}
func BenchmarkMul(b *testing.B) {
var x, y radix51.FieldElement
x.One()
y.Add(radix51.One, radix51.One)
b.ResetTimer()
for i := 0; i < b.N; i++ {
x.Mul(&x, &y)
}
}
func BenchmarkMul32(b *testing.B) {
var x radix51.FieldElement
x.One()
b.ResetTimer()
for i := 0; i < b.N; i++ {
x.Mul32(&x, 0xaa42aa42)
}
}

@@ -1,392 +0,0 @@
// Copyright (c) 2017 George Tankersley. All rights reserved.
// Copyright (c) 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package radix51 implements GF(2^255-19) field arithmetic in radix 2^51
// representation. This code is a port of the public domain amd64-51-30k version
// of ed25519 from SUPERCOP.
//
// The interface works similarly to math/big.Int, and all arguments and
// receivers are allowed to alias.
package radix51
import (
"crypto/subtle"
"encoding/binary"
"math/big"
"math/bits"
)
// FieldElement represents an element of the field GF(2^255-19). An element t
// represents the integer t[0] + t[1]*2^51 + t[2]*2^102 + t[3]*2^153 +
// t[4]*2^204.
//
// Between operations, all limbs are expected to be lower than 2^51, except the
// first one, which can be up to 2^255 + 2^13 * 19 due to carry propagation.
//
// The zero value is a valid zero element.
type FieldElement [5]uint64
const maskLow51Bits uint64 = (1 << 51) - 1
var (
Zero = &FieldElement{0, 0, 0, 0, 0}
One = &FieldElement{1, 0, 0, 0, 0}
Two = &FieldElement{2, 0, 0, 0, 0}
MinusOne = new(FieldElement).Neg(One)
)
// Zero sets v = 0 and returns v.
func (v *FieldElement) Zero() *FieldElement {
*v = *Zero
return v
}
// One sets v = 1 and returns v.
func (v *FieldElement) One() *FieldElement {
*v = *One
return v
}
// carryPropagate brings the limbs below 52, 51, 51, 51, 51 bits. It is split in
// two because of the inliner heuristics. The two functions MUST be called one
// after the other.
func (v *FieldElement) carryPropagate1() *FieldElement {
v[1] += v[0] >> 51
v[0] &= maskLow51Bits
v[2] += v[1] >> 51
v[1] &= maskLow51Bits
v[3] += v[2] >> 51
v[2] &= maskLow51Bits
return v
}
func (v *FieldElement) carryPropagate2() *FieldElement {
v[4] += v[3] >> 51
v[3] &= maskLow51Bits
v[0] += (v[4] >> 51) * 19
v[4] &= maskLow51Bits
return v
}
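The `* 19` in the wrap-around carry is the usual reduction identity for this field: since p = 2^255 - 19,

2^255 ≡ 19 (mod p)

so any bits carried out of the top limb (weight 2^255) can be folded back into the bottom limb after multiplying by 19 without changing the value mod p.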
// reduce reduces v modulo 2^255 - 19 and returns it.
func (v *FieldElement) reduce() *FieldElement {
v.carryPropagate1().carryPropagate2()
// After the light reduction we now have a field element representation
// v < 2^255 + 2^13 * 19, but need v < 2^255 - 19.
// If v >= 2^255 - 19, then v + 19 >= 2^255, which would overflow 2^255 - 1,
// generating a carry. That is, c will be 0 if v < 2^255 - 19, and 1 otherwise.
c := (v[0] + 19) >> 51
c = (v[1] + c) >> 51
c = (v[2] + c) >> 51
c = (v[3] + c) >> 51
c = (v[4] + c) >> 51
// If v < 2^255 - 19 and c = 0, this will be a no-op. Otherwise, it's
// effectively applying the reduction identity to the carry.
v[0] += 19 * c
v[1] += v[0] >> 51
v[0] = v[0] & maskLow51Bits
v[2] += v[1] >> 51
v[1] = v[1] & maskLow51Bits
v[3] += v[2] >> 51
v[2] = v[2] & maskLow51Bits
v[4] += v[3] >> 51
v[3] = v[3] & maskLow51Bits
// no additional carry
v[4] = v[4] & maskLow51Bits
return v
}
// Add sets v = a + b and returns v.
func (v *FieldElement) Add(a, b *FieldElement) *FieldElement {
v[0] = a[0] + b[0]
v[1] = a[1] + b[1]
v[2] = a[2] + b[2]
v[3] = a[3] + b[3]
v[4] = a[4] + b[4]
return v.carryPropagate1().carryPropagate2()
}
// Sub sets v = a - b and returns v.
func (v *FieldElement) Sub(a, b *FieldElement) *FieldElement {
// We first add 2 * p, to guarantee the subtraction won't underflow, and
// then subtract b (which can be up to 2^255 + 2^13 * 19).
v[0] = (a[0] + 0xFFFFFFFFFFFDA) - b[0]
v[1] = (a[1] + 0xFFFFFFFFFFFFE) - b[1]
v[2] = (a[2] + 0xFFFFFFFFFFFFE) - b[2]
v[3] = (a[3] + 0xFFFFFFFFFFFFE) - b[3]
v[4] = (a[4] + 0xFFFFFFFFFFFFE) - b[4]
return v.carryPropagate1().carryPropagate2()
}
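The magic constants added before subtracting are the radix-2^51 limbs of 2p:

2p = 2·(2^255 - 19) = (2·(2^51 - 19), 2·(2^51 - 1), 2·(2^51 - 1), 2·(2^51 - 1), 2·(2^51 - 1))
   = (0xFFFFFFFFFFFDA, 0xFFFFFFFFFFFFE, 0xFFFFFFFFFFFFE, 0xFFFFFFFFFFFFE, 0xFFFFFFFFFFFFE)

Adding 2p keeps the value unchanged modulo p while making every limb-wise difference nonnegative.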
// Neg sets v = -a and returns v.
func (v *FieldElement) Neg(a *FieldElement) *FieldElement {
return v.Sub(Zero, a)
}
// Invert sets v = 1/z mod p and returns v.
func (v *FieldElement) Invert(z *FieldElement) *FieldElement {
// Inversion is implemented as exponentiation with exponent p - 2. It uses the
// same sequence of 255 squarings and 11 multiplications as [Curve25519].
var z2, z9, z11, z2_5_0, z2_10_0, z2_20_0, z2_50_0, z2_100_0, t FieldElement
z2.Square(z) // 2
t.Square(&z2) // 4
t.Square(&t) // 8
z9.Mul(&t, z) // 9
z11.Mul(&z9, &z2) // 11
t.Square(&z11) // 22
z2_5_0.Mul(&t, &z9) // 2^5 - 2^0 = 31
t.Square(&z2_5_0) // 2^6 - 2^1
for i := 0; i < 4; i++ {
t.Square(&t) // 2^10 - 2^5
}
z2_10_0.Mul(&t, &z2_5_0) // 2^10 - 2^0
t.Square(&z2_10_0) // 2^11 - 2^1
for i := 0; i < 9; i++ {
t.Square(&t) // 2^20 - 2^10
}
z2_20_0.Mul(&t, &z2_10_0) // 2^20 - 2^0
t.Square(&z2_20_0) // 2^21 - 2^1
for i := 0; i < 19; i++ {
t.Square(&t) // 2^40 - 2^20
}
t.Mul(&t, &z2_20_0) // 2^40 - 2^0
t.Square(&t) // 2^41 - 2^1
for i := 0; i < 9; i++ {
t.Square(&t) // 2^50 - 2^10
}
z2_50_0.Mul(&t, &z2_10_0) // 2^50 - 2^0
t.Square(&z2_50_0) // 2^51 - 2^1
for i := 0; i < 49; i++ {
t.Square(&t) // 2^100 - 2^50
}
z2_100_0.Mul(&t, &z2_50_0) // 2^100 - 2^0
t.Square(&z2_100_0) // 2^101 - 2^1
for i := 0; i < 99; i++ {
t.Square(&t) // 2^200 - 2^100
}
t.Mul(&t, &z2_100_0) // 2^200 - 2^0
t.Square(&t) // 2^201 - 2^1
for i := 0; i < 49; i++ {
t.Square(&t) // 2^250 - 2^50
}
t.Mul(&t, &z2_50_0) // 2^250 - 2^0
t.Square(&t) // 2^251 - 2^1
t.Square(&t) // 2^252 - 2^2
t.Square(&t) // 2^253 - 2^3
t.Square(&t) // 2^254 - 2^4
t.Square(&t) // 2^255 - 2^5
return v.Mul(&t, &z11) // 2^255 - 21
}
// Set sets v = a and returns v.
func (v *FieldElement) Set(a *FieldElement) *FieldElement {
*v = *a
return v
}
// FromBytes sets v to x, which must be a 32 bytes little-endian encoding.
//
// Consistently with RFC 7748, the most significant bit (the high bit of the
// last byte) is ignored, and non-canonical values (2^255-19 through 2^255-1)
// are accepted.
func (v *FieldElement) FromBytes(x []byte) *FieldElement {
if len(x) != 32 {
panic("ed25519: invalid field element input size")
}
// Provide headroom for the slight binary.LittleEndian.Uint64 overread. (We
// read 64 bits at an offset of 200, but then take only 4+51 into account.)
var buf [33]byte
copy(buf[:], x)
for i := range v {
bitsOffset := i * 51
v[i] = binary.LittleEndian.Uint64(buf[bitsOffset/8:])
v[i] >>= uint(bitsOffset % 8)
v[i] &= maskLow51Bits
}
return v
}
// Bytes appends a 32 bytes little-endian encoding of v to b.
func (v *FieldElement) Bytes(b []byte) []byte {
t := *v
t.reduce()
res, out := sliceForAppend(b, 32)
for i := range out {
out[i] = 0
}
var buf [8]byte
for i := range t {
bitsOffset := i * 51
binary.LittleEndian.PutUint64(buf[:], t[i]<<uint(bitsOffset%8))
for i, b := range buf {
off := bitsOffset/8 + i
if off >= len(out) {
break
}
out[off] |= b
}
}
return res
}
// sliceForAppend extends the input slice by n bytes. head is the full extended
// slice, while tail is the appended part. If the original slice has sufficient
// capacity no allocation is performed.
func sliceForAppend(in []byte, n int) (head, tail []byte) {
if total := len(in) + n; cap(in) >= total {
head = in[:total]
} else {
head = make([]byte, total)
copy(head, in)
}
tail = head[len(in):]
return
}
// FromBig sets v = n and returns v. The bit length of n must not exceed 256.
func (v *FieldElement) FromBig(n *big.Int) *FieldElement {
if n.BitLen() > 32*8 {
panic("ed25519: invalid field element input size")
}
buf := make([]byte, 0, 32)
for _, word := range n.Bits() {
for i := 0; i < bits.UintSize; i += 8 {
if len(buf) >= cap(buf) {
break
}
buf = append(buf, byte(word))
word >>= 8
}
}
return v.FromBytes(buf[:32])
}
// ToBig returns v as a big.Int.
func (v *FieldElement) ToBig() *big.Int {
buf := v.Bytes(nil)
words := make([]big.Word, 32*8/bits.UintSize)
for n := range words {
for i := 0; i < bits.UintSize; i += 8 {
if len(buf) == 0 {
break
}
words[n] |= big.Word(buf[0]) << big.Word(i)
buf = buf[1:]
}
}
return new(big.Int).SetBits(words)
}
// Equal returns 1 if v and u are equal, and 0 otherwise.
func (v *FieldElement) Equal(u *FieldElement) int {
var sa, sv [32]byte
u.Bytes(sa[:0])
v.Bytes(sv[:0])
return subtle.ConstantTimeCompare(sa[:], sv[:])
}
const mask64Bits uint64 = (1 << 64) - 1
// Select sets v to a if cond == 1, and to b if cond == 0.
func (v *FieldElement) Select(a, b *FieldElement, cond int) *FieldElement {
m := uint64(cond) * mask64Bits
v[0] = (m & a[0]) | (^m & b[0])
v[1] = (m & a[1]) | (^m & b[1])
v[2] = (m & a[2]) | (^m & b[2])
v[3] = (m & a[3]) | (^m & b[3])
v[4] = (m & a[4]) | (^m & b[4])
return v
}
// CondSwap swaps a and b if cond == 1 or leaves them unchanged if cond == 0.
func CondSwap(a, b *FieldElement, cond int) {
m := uint64(cond) * mask64Bits
t := m & (a[0] ^ b[0])
a[0] ^= t
b[0] ^= t
t = m & (a[1] ^ b[1])
a[1] ^= t
b[1] ^= t
t = m & (a[2] ^ b[2])
a[2] ^= t
b[2] ^= t
t = m & (a[3] ^ b[3])
a[3] ^= t
b[3] ^= t
t = m & (a[4] ^ b[4])
a[4] ^= t
b[4] ^= t
}
// CondNeg sets v to -u if cond == 1, and to u if cond == 0.
func (v *FieldElement) CondNeg(u *FieldElement, cond int) *FieldElement {
tmp := new(FieldElement).Neg(u)
return v.Select(tmp, u, cond)
}
// IsNegative returns 1 if v is negative, and 0 otherwise.
func (v *FieldElement) IsNegative() int {
var b [32]byte
v.Bytes(b[:0])
return int(b[0] & 1)
}
// Abs sets v to |u| and returns v.
func (v *FieldElement) Abs(u *FieldElement) *FieldElement {
return v.CondNeg(u, u.IsNegative())
}
// Mul sets v = x * y and returns v.
func (v *FieldElement) Mul(x, y *FieldElement) *FieldElement {
feMul(v, x, y)
return v
}
// Square sets v = x * x and returns v.
func (v *FieldElement) Square(x *FieldElement) *FieldElement {
feSquare(v, x)
return v
}
// Mul32 sets v = x * y and returns v.
func (v *FieldElement) Mul32(x *FieldElement, y uint32) *FieldElement {
x0lo, x0hi := mul51(x[0], y)
x1lo, x1hi := mul51(x[1], y)
x2lo, x2hi := mul51(x[2], y)
x3lo, x3hi := mul51(x[3], y)
x4lo, x4hi := mul51(x[4], y)
v[0] = x0lo + 19*x4hi // carried over per the reduction identity
v[1] = x1lo + x0hi
v[2] = x2lo + x1hi
v[3] = x3lo + x2hi
v[4] = x4lo + x3hi
// The hi portions are going to be only 32 bits, plus any previous excess,
// so we can skip the carry propagation.
return v
}

@@ -1,13 +0,0 @@
// Copyright (c) 2017 George Tankersley. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build amd64,!purego
package radix51
//go:noescape
func feMul(out, a, b *FieldElement)
//go:noescape
func feSquare(out, x *FieldElement)

@@ -1,348 +0,0 @@
// Copyright (c) 2017 George Tankersley. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build amd64,!purego
// func feMul(out, a, b *FieldElement)
TEXT ·feMul(SB),$0-24
// Based on assembly generated by PeachPy. Equivalent to the Go in
// feMulGeneric, which was originally based on the amd64-51-30k
// assembly in SUPERCOP.
MOVQ out+0(FP), DI
MOVQ a+8(FP), BX
MOVQ b+16(FP), CX
// Calculate r0
MOVQ 0(BX), AX // rax <-- x0
MULQ 0(CX) // rdx, rax <-- x0*y0
MOVQ AX, SI // r00 = rax
MOVQ DX, BP // r01 = rdx
MOVQ 8(BX), DX // rdx <-- x1
IMUL3Q $19, DX, AX // rax <-- x1*19
MULQ 32(CX) // rdx, rax <-- x1_19*y4
ADDQ AX, SI // r00 += rax
ADCQ DX, BP // r01 += rdx
MOVQ 16(BX), DX // rdx <-- x2
IMUL3Q $19, DX, AX // rax <-- x2*19
MULQ 24(CX) // rdx, rax <-- x2_19*y3
ADDQ AX, SI // r00 += rax
ADCQ DX, BP // r01 += rdx
MOVQ 24(BX), DX // rdx <-- x3
IMUL3Q $19, DX, AX // rax <-- x3*19
MULQ 16(CX) // rdx, rax <-- x3_19 * y2
ADDQ AX, SI // r00 += rax
ADCQ DX, BP // r01 += rdx
MOVQ 32(BX), DX // rdx <-- x4
IMUL3Q $19, DX, AX // rax <-- x4*19
MULQ 8(CX) // rdx rax <-- x4_19*y1
ADDQ AX, SI // r00 += rax
ADCQ DX, BP // r01 += rdx
// Calculate r1
MOVQ 0(BX), AX
MULQ 8(CX)
MOVQ AX, R8 // r10
MOVQ DX, R9 // r11
MOVQ 8(BX), AX
MULQ 0(CX)
ADDQ AX, R8
ADCQ DX, R9
MOVQ 16(BX), DX
IMUL3Q $19, DX, AX
MULQ 32(CX)
ADDQ AX, R8
ADCQ DX, R9
MOVQ 24(BX), DX
IMUL3Q $19, DX, AX
MULQ 24(CX)
ADDQ AX, R8
ADCQ DX, R9
MOVQ 32(BX), DX
IMUL3Q $19, DX, AX
MULQ 16(CX)
ADDQ AX, R8
ADCQ DX, R9
// Calculate r2
MOVQ 0(BX), AX
MULQ 16(CX)
MOVQ AX, R10 // r20
MOVQ DX, R11 // r21
MOVQ 8(BX), AX
MULQ 8(CX)
ADDQ AX, R10
ADCQ DX, R11
MOVQ 16(BX), AX
MULQ 0(CX)
ADDQ AX, R10
ADCQ DX, R11
MOVQ 24(BX), DX
IMUL3Q $19, DX, AX
MULQ 32(CX)
ADDQ AX, R10
ADCQ DX, R11
MOVQ 32(BX), DX
IMUL3Q $19, DX, AX
MULQ 24(CX)
ADDQ AX, R10
ADCQ DX, R11
// Calculate r3
MOVQ 0(BX), AX
MULQ 24(CX)
MOVQ AX, R12 // r30
MOVQ DX, R13 // r31
MOVQ 8(BX), AX
MULQ 16(CX)
ADDQ AX, R12
ADCQ DX, R13
MOVQ 16(BX), AX
MULQ 8(CX)
ADDQ AX, R12
ADCQ DX, R13
MOVQ 24(BX), AX
MULQ 0(CX)
ADDQ AX, R12
ADCQ DX, R13
MOVQ 32(BX), DX
IMUL3Q $19, DX, AX
MULQ 32(CX)
ADDQ AX, R12
ADCQ DX, R13
// Calculate r4
MOVQ 0(BX), AX
MULQ 32(CX)
MOVQ AX, R14 // r40
MOVQ DX, R15 // r41
MOVQ 8(BX), AX
MULQ 24(CX)
ADDQ AX, R14
ADCQ DX, R15
MOVQ 16(BX), AX
MULQ 16(CX)
ADDQ AX, R14
ADCQ DX, R15
MOVQ 24(BX), AX
MULQ 8(CX)
ADDQ AX, R14
ADCQ DX, R15
MOVQ 32(BX), AX
MULQ 0(CX)
ADDQ AX, R14
ADCQ DX, R15
MOVQ $2251799813685247, AX // (1<<51) - 1
SHLQ $13, SI, BP // r01 = shld with r00
ANDQ AX, SI // r00 &= mask51
SHLQ $13, R8, R9 // r11 = shld with r10
ANDQ AX, R8 // r10 &= mask51
ADDQ BP, R8 // r10 += r01
SHLQ $13, R10, R11 // r21 = shld with r20
ANDQ AX, R10 // r20 &= mask51
ADDQ R9, R10 // r20 += r11
SHLQ $13, R12, R13 // r31 = shld with r30
ANDQ AX, R12 // r30 &= mask51
ADDQ R11, R12 // r30 += r21
SHLQ $13, R14, R15 // r41 = shld with r40
ANDQ AX, R14 // r40 &= mask51
ADDQ R13, R14 // r40 += r31
IMUL3Q $19, R15, R15 // r41 = r41*19
ADDQ R15, SI // r00 += r41
MOVQ SI, DX // rdx <-- r00
SHRQ $51, DX // rdx <-- r00 >> 51
ADDQ DX, R8 // r10 += r00 >> 51
MOVQ R8, DX // rdx <-- r10
SHRQ $51, DX // rdx <-- r10 >> 51
ANDQ AX, SI // r00 &= mask51
ADDQ DX, R10 // r20 += r10 >> 51
MOVQ R10, DX // rdx <-- r20
SHRQ $51, DX // rdx <-- r20 >> 51
ANDQ AX, R8 // r10 &= mask51
ADDQ DX, R12 // r30 += r20 >> 51
MOVQ R12, DX // rdx <-- r30
SHRQ $51, DX // rdx <-- r30 >> 51
ANDQ AX, R10 // r20 &= mask51
ADDQ DX, R14 // r40 += r30 >> 51
MOVQ R14, DX // rdx <-- r40
SHRQ $51, DX // rdx <-- r40 >> 51
ANDQ AX, R12 // r30 &= mask51
IMUL3Q $19, DX, DX // rdx <-- (r40 >> 51) * 19
ADDQ DX, SI // r00 += (r40 >> 51) *19
ANDQ AX, R14 // r40 &= mask51
MOVQ SI, 0(DI)
MOVQ R8, 8(DI)
MOVQ R10, 16(DI)
MOVQ R12, 24(DI)
MOVQ R14, 32(DI)
RET
// func feSquare(out, x *FieldElement)
TEXT ·feSquare(SB),4,$0-16
MOVQ out+0(FP), DI
MOVQ x+8(FP), SI
// r0 = x0*x0 + x1*38*x4 + x2*38*x3
MOVQ 0(SI), AX
MULQ 0(SI)
MOVQ AX, CX // r00
MOVQ DX, R8 // r01
MOVQ 8(SI), DX
IMUL3Q $38, DX, AX
MULQ 32(SI)
ADDQ AX, CX
ADCQ DX, R8
MOVQ 16(SI), DX
IMUL3Q $38, DX, AX
MULQ 24(SI)
ADDQ AX, CX
ADCQ DX, R8
// r1 = x0*2*x1 + x2*38*x4 + x3*19*x3
MOVQ 0(SI), AX
SHLQ $1, AX
MULQ 8(SI)
MOVQ AX, R9 // r10
MOVQ DX, R10 // r11
MOVQ 16(SI), DX
IMUL3Q $38, DX, AX
MULQ 32(SI)
ADDQ AX, R9
ADCQ DX, R10
MOVQ 24(SI), DX
IMUL3Q $19, DX, AX
MULQ 24(SI)
ADDQ AX, R9
ADCQ DX, R10
// r2 = x0*2*x2 + x1*x1 + x3*38*x4
MOVQ 0(SI), AX
SHLQ $1, AX
MULQ 16(SI)
MOVQ AX, R11 // r20
MOVQ DX, R12 // r21
MOVQ 8(SI), AX
MULQ 8(SI)
ADDQ AX, R11
ADCQ DX, R12
MOVQ 24(SI), DX
IMUL3Q $38, DX, AX
MULQ 32(SI)
ADDQ AX, R11
ADCQ DX, R12
// r3 = x0*2*x3 + x1*2*x2 + x4*19*x4
MOVQ 0(SI), AX
SHLQ $1, AX
MULQ 24(SI)
MOVQ AX, R13 // r30
MOVQ DX, R14 // r31
MOVQ 8(SI), AX
SHLQ $1, AX
MULQ 16(SI)
ADDQ AX, R13
ADCQ DX, R14
MOVQ 32(SI), DX
IMUL3Q $19, DX, AX
MULQ 32(SI)
ADDQ AX, R13
ADCQ DX, R14
// r4 = x0*2*x4 + x1*2*x3 + x2*x2
MOVQ 0(SI), AX
SHLQ $1, AX
MULQ 32(SI)
MOVQ AX, R15 // r40
MOVQ DX, BX // r41
MOVQ 8(SI), AX
SHLQ $1, AX
MULQ 24(SI)
ADDQ AX, R15
ADCQ DX, BX
MOVQ 16(SI), AX
MULQ 16(SI)
ADDQ AX, R15
ADCQ DX, BX
// Reduce
MOVQ $2251799813685247, AX // (1<<51) - 1
SHLQ $13, CX, R8 // r01 = shld with r00
ANDQ AX, CX // r00 &= mask51
SHLQ $13, R9, R10 // r11 = shld with r10
ANDQ AX, R9 // r10 &= mask51
ADDQ R8, R9 // r10 += r01
SHLQ $13, R11, R12 // r21 = shld with r20
ANDQ AX, R11 // r20 &= mask51
ADDQ R10, R11 // r20 += r11
SHLQ $13, R13, R14 // r31 = shld with r30
ANDQ AX, R13 // r30 &= mask51
ADDQ R12, R13 // r30 += r21
SHLQ $13, R15, BX // r41 = shld with r40
ANDQ AX, R15 // r40 &= mask51
ADDQ R14, R15 // r40 += r31
IMUL3Q $19, BX, DX // r41 = r41*19
ADDQ DX, CX // r00 += r41
MOVQ CX, DX // rdx <-- r00
SHRQ $51, DX // rdx <-- r00 >> 51
ADDQ DX, R9 // r10 += r00 >> 51
MOVQ R9, DX // rdx <-- r10
SHRQ $51, DX // rdx <-- r10 >> 51
ANDQ AX, CX // r00 &= mask51
ADDQ DX, R11 // r20 += r10 >> 51
MOVQ R11, DX // rdx <-- r20
SHRQ $51, DX // rdx <-- r20 >> 51
ANDQ AX, R9 // r10 &= mask51
ADDQ DX, R13 // r30 += r20 >> 51
MOVQ R13, DX // rdx <-- r30
SHRQ $51, DX // rdx <-- r30 >> 51
ANDQ AX, R11 // r20 &= mask51
ADDQ DX, R15 // r40 += r30 >> 51
MOVQ R15, DX // rdx <-- r40
SHRQ $51, DX // rdx <-- r40 >> 51
ANDQ AX, R13 // r30 &= mask51
IMUL3Q $19, DX, DX // rdx <-- (r40 >> 51) * 19
ADDQ DX, CX // r00 += (r40 >> 51) *19
ANDQ AX, R15 // r40 &= mask51
MOVQ CX, 0(DI)
MOVQ R9, 8(DI)
MOVQ R11, 16(DI)
MOVQ R13, 24(DI)
MOVQ R15, 32(DI)
RET

@@ -1,179 +0,0 @@
// Copyright (c) 2017 George Tankersley. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package radix51
func feMulGeneric(v, x, y *FieldElement) {
x0 := x[0]
x1 := x[1]
x2 := x[2]
x3 := x[3]
x4 := x[4]
y0 := y[0]
y1 := y[1]
y2 := y[2]
y3 := y[3]
y4 := y[4]
// Reduction can be carried out simultaneously to multiplication. For
// example, we do not compute a coefficient r_5. Whenever the result of a
// mul instruction belongs to r_5, for example in the multiplication of
// x_3*y_2, we multiply one of the inputs by 19 and add the result to r_0.
x1_19 := x1 * 19
x2_19 := x2 * 19
x3_19 := x3 * 19
x4_19 := x4 * 19
// calculate r0 = x0*y0 + 19*(x1*y4 + x2*y3 + x3*y2 + x4*y1)
r00, r01 := madd64(0, 0, x0, y0)
r00, r01 = madd64(r00, r01, x1_19, y4)
r00, r01 = madd64(r00, r01, x2_19, y3)
r00, r01 = madd64(r00, r01, x3_19, y2)
r00, r01 = madd64(r00, r01, x4_19, y1)
// calculate r1 = x0*y1 + x1*y0 + 19*(x2*y4 + x3*y3 + x4*y2)
r10, r11 := madd64(0, 0, x0, y1)
r10, r11 = madd64(r10, r11, x1, y0)
r10, r11 = madd64(r10, r11, x2_19, y4)
r10, r11 = madd64(r10, r11, x3_19, y3)
r10, r11 = madd64(r10, r11, x4_19, y2)
// calculate r2 = x0*y2 + x1*y1 + x2*y0 + 19*(x3*y4 + x4*y3)
r20, r21 := madd64(0, 0, x0, y2)
r20, r21 = madd64(r20, r21, x1, y1)
r20, r21 = madd64(r20, r21, x2, y0)
r20, r21 = madd64(r20, r21, x3_19, y4)
r20, r21 = madd64(r20, r21, x4_19, y3)
// calculate r3 = x0*y3 + x1*y2 + x2*y1 + x3*y0 + 19*x4*y4
r30, r31 := madd64(0, 0, x0, y3)
r30, r31 = madd64(r30, r31, x1, y2)
r30, r31 = madd64(r30, r31, x2, y1)
r30, r31 = madd64(r30, r31, x3, y0)
r30, r31 = madd64(r30, r31, x4_19, y4)
// calculate r4 = x0*y4 + x1*y3 + x2*y2 + x3*y1 + x4*y0
r40, r41 := madd64(0, 0, x0, y4)
r40, r41 = madd64(r40, r41, x1, y3)
r40, r41 = madd64(r40, r41, x2, y2)
r40, r41 = madd64(r40, r41, x3, y1)
r40, r41 = madd64(r40, r41, x4, y0)
// After the multiplication we need to reduce (carry) the 5 coefficients to
// obtain a result with coefficients that are at most slightly larger than
// 2^51. Denote the two registers holding coefficient r_0 as r_00 and r_01,
// with r_0 = 2^64*r_01 + r_00. Similarly denote the two registers holding
// coefficient r_1 as r_10 and r_11. We first shift r_01 left by 13, while
// shifting in the most significant bits of r_00 (shld instruction), and
// then compute the logical and of r_00 with 2^51 - 1. We do the same with
// r_10 and r_11 and add r_01 into r_10 after the logical and with
// 2^51 - 1. We proceed this way for coefficients r_2,...,r_4; register
// r_41 is multiplied by 19 before adding it to r_00.
r01 = (r01 << 13) | (r00 >> 51)
r00 &= maskLow51Bits
r11 = (r11 << 13) | (r10 >> 51)
r10 &= maskLow51Bits
r10 += r01
r21 = (r21 << 13) | (r20 >> 51)
r20 &= maskLow51Bits
r20 += r11
r31 = (r31 << 13) | (r30 >> 51)
r30 &= maskLow51Bits
r30 += r21
r41 = (r41 << 13) | (r40 >> 51)
r40 &= maskLow51Bits
r40 += r31
r41 *= 19
r00 += r41
// Now all 5 coefficients fit into 64-bit registers but are still too large
// to be used as input to another multiplication. We therefore carry from
// r_0 to r_1, from r_1 to r_2, from r_2 to r_3, from r_3 to r_4, and
// finally from r_4 to r_0. Each of these carries is done as one copy, one
// right shift by 51, one logical and with 2^51 - 1, and one addition.
*v = FieldElement{r00, r10, r20, r30, r40}
v.carryPropagate1().carryPropagate2()
}
func feSquareGeneric(v, x *FieldElement) {
// Squaring needs only 15 mul instructions. Some inputs are multiplied by 2;
// this is combined with multiplication by 19 where possible. The coefficient
// reduction after squaring is the same as for multiplication.
x0 := x[0]
x1 := x[1]
x2 := x[2]
x3 := x[3]
x4 := x[4]
x0_2 := x0 << 1
x1_2 := x1 << 1
x1_38 := x1 * 38
x2_38 := x2 * 38
x3_38 := x3 * 38
x3_19 := x3 * 19
x4_19 := x4 * 19
// r0 = x0*x0 + x1*38*x4 + x2*38*x3
r00, r01 := madd64(0, 0, x0, x0)
r00, r01 = madd64(r00, r01, x1_38, x4)
r00, r01 = madd64(r00, r01, x2_38, x3)
// r1 = x0*2*x1 + x2*38*x4 + x3*19*x3
r10, r11 := madd64(0, 0, x0_2, x1)
r10, r11 = madd64(r10, r11, x2_38, x4)
r10, r11 = madd64(r10, r11, x3_19, x3)
// r2 = x0*2*x2 + x1*x1 + x3*38*x4
r20, r21 := madd64(0, 0, x0_2, x2)
r20, r21 = madd64(r20, r21, x1, x1)
r20, r21 = madd64(r20, r21, x3_38, x4)
// r3 = x0*2*x3 + x1*2*x2 + x4*19*x4
r30, r31 := madd64(0, 0, x0_2, x3)
r30, r31 = madd64(r30, r31, x1_2, x2)
r30, r31 = madd64(r30, r31, x4_19, x4)
// r4 = x0*2*x4 + x1*2*x3 + x2*x2
r40, r41 := madd64(0, 0, x0_2, x4)
r40, r41 = madd64(r40, r41, x1_2, x3)
r40, r41 = madd64(r40, r41, x2, x2)
// Same reduction
r01 = (r01 << 13) | (r00 >> 51)
r00 &= maskLow51Bits
r11 = (r11 << 13) | (r10 >> 51)
r10 &= maskLow51Bits
r10 += r01
r21 = (r21 << 13) | (r20 >> 51)
r20 &= maskLow51Bits
r20 += r11
r31 = (r31 << 13) | (r30 >> 51)
r30 &= maskLow51Bits
r30 += r21
r41 = (r41 << 13) | (r40 >> 51)
r40 &= maskLow51Bits
r40 += r31
r41 *= 19
r00 += r41
*v = FieldElement{r00, r10, r20, r30, r40}
v.carryPropagate1().carryPropagate2()
}
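To see where the 38 factors come from, expand the r0 row the same way as in the multiplication above: terms whose limb indices sum to 5 are folded back into r0 with a factor of 19, and in a square each cross term appears twice, so
r0 = x0*x0 + 19*(x1*x4 + x2*x3 + x3*x2 + x4*x1)
   = x0*x0 + 2*19*x1*x4 + 2*19*x2*x3
   = x0*x0 + (38*x1)*x4 + (38*x2)*x3
which is exactly the three madd64 calls that accumulate r00, r01 using the precomputed x1_38 and x2_38.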

View File

@ -1,11 +0,0 @@
// Copyright (c) 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build !amd64 purego
package radix51
func feMul(v, x, y *FieldElement) { feMulGeneric(v, x, y) }
func feSquare(v, x *FieldElement) { feSquareGeneric(v, x) }

View File

@ -1,395 +0,0 @@
// Copyright (c) 2017 George Tankersley. All rights reserved.
// Copyright (c) 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package radix51
import (
"bytes"
"crypto/rand"
"io"
"math/big"
"math/bits"
mathrand "math/rand"
"reflect"
"testing"
"testing/quick"
)
// quickCheckConfig will make each quickcheck test run (1024 * -quickchecks)
// times. The default value of -quickchecks is 100.
var quickCheckConfig = &quick.Config{MaxCountScale: 1 << 10}
func generateFieldElement(rand *mathrand.Rand) FieldElement {
// Generation strategy: generate random limb values of [52, 51, 51, 51, 51]
// bits, like the ones returned by lightReduce.
const maskLow52Bits = (1 << 52) - 1
return FieldElement{
rand.Uint64() & maskLow52Bits,
rand.Uint64() & maskLow51Bits,
rand.Uint64() & maskLow51Bits,
rand.Uint64() & maskLow51Bits,
rand.Uint64() & maskLow51Bits,
}
}
// weirdLimbs can be combined to generate a range of edge-case field elements.
// 0 and -1 are intentionally more weighted, as they combine well.
var (
weirdLimbs51 = []uint64{
0, 0, 0, 0,
1,
19 - 1,
19,
0x2aaaaaaaaaaaa,
0x5555555555555,
(1 << 51) - 20,
(1 << 51) - 19,
(1 << 51) - 1, (1 << 51) - 1,
(1 << 51) - 1, (1 << 51) - 1,
}
weirdLimbs52 = []uint64{
0, 0, 0, 0, 0, 0,
1,
19 - 1,
19,
0x2aaaaaaaaaaaa,
0x5555555555555,
(1 << 51) - 20,
(1 << 51) - 19,
(1 << 51) - 1, (1 << 51) - 1,
(1 << 51) - 1, (1 << 51) - 1,
(1 << 51) - 1, (1 << 51) - 1,
1 << 51,
(1 << 51) + 1,
(1 << 52) - 19,
(1 << 52) - 1,
}
)
func generateWeirdFieldElement(rand *mathrand.Rand) FieldElement {
return FieldElement{
weirdLimbs52[rand.Intn(len(weirdLimbs52))],
weirdLimbs51[rand.Intn(len(weirdLimbs51))],
weirdLimbs51[rand.Intn(len(weirdLimbs51))],
weirdLimbs51[rand.Intn(len(weirdLimbs51))],
weirdLimbs51[rand.Intn(len(weirdLimbs51))],
}
}
func (x FieldElement) Generate(rand *mathrand.Rand, size int) reflect.Value {
if rand.Intn(2) == 0 {
return reflect.ValueOf(generateWeirdFieldElement(rand))
}
return reflect.ValueOf(generateFieldElement(rand))
}
// isInBounds returns whether the element is within the expected bit size bounds
// after a light reduction.
func isInBounds(x *FieldElement) bool {
return bits.Len64(x[0]) <= 52 &&
bits.Len64(x[1]) <= 51 &&
bits.Len64(x[2]) <= 51 &&
bits.Len64(x[3]) <= 51 &&
bits.Len64(x[4]) <= 51
}
func TestMulDistributesOverAdd(t *testing.T) {
mulDistributesOverAdd := func(x, y, z FieldElement) bool {
// Compute t1 = (x+y)*z
t1 := new(FieldElement)
t1.Add(&x, &y)
t1.Mul(t1, &z)
// Compute t2 = x*z + y*z
t2 := new(FieldElement)
t3 := new(FieldElement)
t2.Mul(&x, &z)
t3.Mul(&y, &z)
t2.Add(t2, t3)
return t1.Equal(t2) == 1 && isInBounds(t1) && isInBounds(t2)
}
if err := quick.Check(mulDistributesOverAdd, quickCheckConfig); err != nil {
t.Error(err)
}
}
func TestMul64to128(t *testing.T) {
a := uint64(5)
b := uint64(5)
r0, r1 := madd64(0, 0, a, b)
if r0 != 0x19 || r1 != 0 {
t.Errorf("lo-range wide mult failed, got %d + %d*(2**64)", r0, r1)
}
a = uint64(18014398509481983) // 2^54 - 1
b = uint64(18014398509481983) // 2^54 - 1
r0, r1 = madd64(0, 0, a, b)
if r0 != 0xff80000000000001 || r1 != 0xfffffffffff {
t.Errorf("hi-range wide mult failed, got %d + %d*(2**64)", r0, r1)
}
a = uint64(1125899906842661)
b = uint64(2097155)
r0, r1 = madd64(0, 0, a, b)
r0, r1 = madd64(r0, r1, a, b)
r0, r1 = madd64(r0, r1, a, b)
r0, r1 = madd64(r0, r1, a, b)
r0, r1 = madd64(r0, r1, a, b)
if r0 != 16888498990613035 || r1 != 640 {
t.Errorf("wrong answer: %d + %d*(2**64)", r0, r1)
}
}
var r0, r1 uint64
func BenchmarkWideMultCall(t *testing.B) {
a := uint64(18014398509481983)
b := uint64(18014398509481983)
for i := 0; i < t.N; i++ {
r0, r1 = madd64(r0, r1, a, b)
}
}
func TestFromBytesRoundTrip(t *testing.T) {
f1 := func(in, out [32]byte, fe FieldElement) bool {
fe.FromBytes(in[:])
fe.Bytes(out[:0])
// Mask the most significant bit as it's ignored by FromBytes. (Now
// instead of earlier so we check the masking in FromBytes is working.)
in[len(in)-1] &= (1 << 7) - 1
// TODO: values in the range [2^255-19, 2^255-1] will still fail the
// comparison as they will have been reduced in the round-trip, but the
// current quickcheck generation strategy will never hit them, which is
// not good. We should have a weird generator that aims for edge cases,
// and we'll know it works when this test breaks.
return bytes.Equal(in[:], out[:]) && isInBounds(&fe)
}
if err := quick.Check(f1, nil); err != nil {
t.Errorf("failed bytes->FE->bytes round-trip: %v", err)
}
f2 := func(fe, r FieldElement, out [32]byte) bool {
fe.Bytes(out[:0])
r.FromBytes(out[:])
// Intentionally not using Equal, so as not to go through Bytes again.
// Calling reduce because both Generate and FromBytes can produce
// non-canonical representations.
fe.reduce()
r.reduce()
return fe == r
}
if err := quick.Check(f2, nil); err != nil {
t.Errorf("failed FE->bytes->FE round-trip: %v", err)
}
// Check some fixed vectors from dalek
type feRTTest struct {
fe FieldElement
b []byte
}
var tests = []feRTTest{
{
fe: FieldElement([5]uint64{358744748052810, 1691584618240980, 977650209285361, 1429865912637724, 560044844278676}),
b: []byte{74, 209, 69, 197, 70, 70, 161, 222, 56, 226, 229, 19, 112, 60, 25, 92, 187, 74, 222, 56, 50, 153, 51, 233, 40, 74, 57, 6, 160, 185, 213, 31},
},
{
fe: FieldElement([5]uint64{84926274344903, 473620666599931, 365590438845504, 1028470286882429, 2146499180330972}),
b: []byte{199, 23, 106, 112, 61, 77, 216, 79, 186, 60, 11, 118, 13, 16, 103, 15, 42, 32, 83, 250, 44, 57, 204, 198, 78, 199, 253, 119, 146, 172, 3, 122},
},
}
for _, tt := range tests {
if !bytes.Equal(tt.fe.Bytes(nil), tt.b) || new(FieldElement).FromBytes(tt.b).Equal(&tt.fe) != 1 {
t.Errorf("Failed fixed roundtrip: %v", tt)
}
}
}
func swapEndianness(buf []byte) []byte {
for i := 0; i < len(buf)/2; i++ {
buf[i], buf[len(buf)-i-1] = buf[len(buf)-i-1], buf[i]
}
return buf
}
func TestBytesBigEquivalence(t *testing.T) {
f1 := func(in, out [32]byte, fe, fe1 FieldElement) bool {
fe.FromBytes(in[:])
in[len(in)-1] &= (1 << 7) - 1 // mask the most significant bit
b := new(big.Int).SetBytes(swapEndianness(in[:]))
fe1.FromBig(b)
if fe != fe1 {
return false
}
fe.Bytes(out[:0])
buf := make([]byte, 32) // pad with zeroes
copy(buf, swapEndianness(fe1.ToBig().Bytes()))
return bytes.Equal(out[:], buf) && isInBounds(&fe) && isInBounds(&fe1)
}
if err := quick.Check(f1, nil); err != nil {
t.Error(err)
}
}
func TestFromBytesRoundTripEdgeCases(t *testing.T) {
// TODO: values close to 0, close to 2^255-19, between 2^255-19 and 2^255-1,
// and between 2^255 and 2^256-1. Test both the documented FromBytes
// behavior, and that Bytes reduces them.
}
// Tests self-consistency between FeMul and FeSquare.
func TestSanity(t *testing.T) {
var x FieldElement
var x2, x2sq FieldElement
// var x2Go, x2sqGo FieldElement
x = [5]uint64{1, 1, 1, 1, 1}
x2.Mul(&x, &x)
// FeMulGo(&x2Go, &x, &x)
x2sq.Square(&x)
// FeSquareGo(&x2sqGo, &x)
// if !vartimeEqual(x2, x2Go) || !vartimeEqual(x2sq, x2sqGo) || !vartimeEqual(x2, x2sq) {
// t.Fatalf("all ones failed\nmul.s: %d\nmul.g: %d\nsqr.s: %d\nsqr.g: %d\n", x2, x2Go, x2sq, x2sqGo)
// }
if x2 != x2sq {
t.Fatalf("all ones failed\nmul: %x\nsqr: %x\n", x2, x2sq)
}
var bytes [32]byte
_, err := io.ReadFull(rand.Reader, bytes[:])
if err != nil {
t.Fatal(err)
}
x.FromBytes(bytes[:])
x2.Mul(&x, &x)
// FeMulGo(&x2Go, &x, &x)
x2sq.Square(&x)
// FeSquareGo(&x2sqGo, &x)
// if !vartimeEqual(x2, x2Go) || !vartimeEqual(x2sq, x2sqGo) || !vartimeEqual(x2, x2sq) {
// t.Fatalf("random field element failed\nfe: %x\n\nmul.s: %x\nmul.g: %x\nsqr.s: %x\nsqr.g: %x\n", x, x2, x2Go, x2sq, x2sqGo)
// }
if x2 != x2sq {
t.Fatalf("all ones failed\nmul: %x\nsqr: %x\n", x2, x2sq)
}
}
func TestEqual(t *testing.T) {
var x FieldElement = [5]uint64{1, 1, 1, 1, 1}
var y FieldElement = [5]uint64{5, 4, 3, 2, 1}
eq := x.Equal(&x)
if eq != 1 {
t.Errorf("wrong about equality")
}
eq = x.Equal(&y)
if eq != 0 {
t.Errorf("wrong about inequality")
}
}
func TestInvert(t *testing.T) {
var x FieldElement = [5]uint64{1, 1, 1, 1, 1}
var one FieldElement = [5]uint64{1, 0, 0, 0, 0}
var xinv, r FieldElement
xinv.Invert(&x)
r.Mul(&x, &xinv)
r.reduce()
if one != r {
t.Errorf("inversion identity failed, got: %x", r)
}
var bytes [32]byte
_, err := io.ReadFull(rand.Reader, bytes[:])
if err != nil {
t.Fatal(err)
}
x.FromBytes(bytes[:])
xinv.Invert(&x)
r.Mul(&x, &xinv)
r.reduce()
if one != r {
t.Errorf("random inversion identity failed, got: %x for field element %x", r, x)
}
}
func TestSelectSwap(t *testing.T) {
a := FieldElement([5]uint64{358744748052810, 1691584618240980, 977650209285361, 1429865912637724, 560044844278676})
b := FieldElement([5]uint64{84926274344903, 473620666599931, 365590438845504, 1028470286882429, 2146499180330972})
var c, d FieldElement
c.Select(&a, &b, 1)
d.Select(&a, &b, 0)
if c.Equal(&a) != 1 || d.Equal(&b) != 1 {
t.Errorf("Select failed")
}
CondSwap(&c, &d, 0)
if c.Equal(&a) != 1 || d.Equal(&b) != 1 {
t.Errorf("Swap failed")
}
CondSwap(&c, &d, 1)
if c.Equal(&b) != 1 || d.Equal(&a) != 1 {
t.Errorf("Swap failed")
}
}
func TestMul32(t *testing.T) {
isAlmostInBounds := func(x *FieldElement) bool {
return bits.Len64(x[0]) <= 52 &&
bits.Len64(x[1]) <= 52 &&
bits.Len64(x[2]) <= 52 &&
bits.Len64(x[3]) <= 52 &&
bits.Len64(x[4]) <= 52
}
mul32EquivalentToMul := func(x FieldElement, y uint32) bool {
t1 := new(FieldElement)
for i := 0; i < 100; i++ {
t1.Mul32(&x, y)
}
ty := new(FieldElement)
ty[0] = uint64(y)
t2 := new(FieldElement)
for i := 0; i < 100; i++ {
t2.Mul(&x, ty)
}
return t1.Equal(t2) == 1 && isAlmostInBounds(t1) && isInBounds(t2)
}
if err := quick.Check(mul32EquivalentToMul, quickCheckConfig); err != nil {
t.Error(err)
}
}

View File

@ -1,28 +0,0 @@
// Copyright (c) 2019 George Tankersley. All rights reserved.
// Copyright (c) 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build go1.13
package radix51
import "math/bits"
// madd64 returns ol + oh * 2⁶⁴ = lo + hi * 2⁶⁴ + a * b. That is, it multiplies
// a and b, and adds the result to the split uint128 [lo,hi].
func madd64(lo, hi, a, b uint64) (ol uint64, oh uint64) {
oh, ol = bits.Mul64(a, b)
var c uint64
ol, c = bits.Add64(ol, lo, 0)
oh, _ = bits.Add64(oh, hi, c)
return
}
// mul51 returns lo + hi * 2⁵¹ = a * b.
func mul51(a uint64, b uint32) (lo uint64, hi uint64) {
mh, ml := bits.Mul64(a, uint64(b))
lo = ml & maskLow51Bits
hi = (mh << 13) | (ml >> 51)
return
}
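A quick standalone check of the mul51 contract (hypothetical, not part of the package): it just verifies lo + hi*2^51 == a*b with big.Int, restating mul51 locally so the snippet compiles on its own.
package main

import (
	"fmt"
	"math/big"
	"math/bits"
)

const maskLow51Bits = (1 << 51) - 1

// mul51 as above: split the 128-bit product a*b at bit 51.
func mul51(a uint64, b uint32) (lo uint64, hi uint64) {
	mh, ml := bits.Mul64(a, uint64(b))
	lo = ml & maskLow51Bits
	hi = (mh << 13) | (ml >> 51)
	return
}

func main() {
	a, b := uint64(1)<<51-1, uint32(1)<<31
	lo, hi := mul51(a, b)
	got := new(big.Int).Lsh(new(big.Int).SetUint64(hi), 51)
	got.Add(got, new(big.Int).SetUint64(lo))
	want := new(big.Int).Mul(new(big.Int).SetUint64(a), new(big.Int).SetUint64(uint64(b)))
	fmt.Println(got.Cmp(want) == 0) // true
}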

View File

@ -1,37 +0,0 @@
// Copyright (c) 2017 George Tankersley. All rights reserved.
// Copyright (c) 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build !go1.13
package radix51
import "unsafe"
// madd64 returns ol + oh * 2⁶⁴ = lo + hi * 2⁶⁴ + a * b. That is, it multiplies
// a and b, and adds the result to the split uint128 [lo,hi].
func madd64(lo, hi, a, b uint64) (ol uint64, oh uint64) {
t1 := (a>>32)*(b&0xFFFFFFFF) + ((a & 0xFFFFFFFF) * (b & 0xFFFFFFFF) >> 32)
t2 := (a&0xFFFFFFFF)*(b>>32) + (t1 & 0xFFFFFFFF)
ol = (a * b) + lo
cmp := ol < lo
oh = hi + (a>>32)*(b>>32) + t1>>32 + t2>>32 + uint64(*(*byte)(unsafe.Pointer(&cmp)))
return
}
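The pre-go1.13 fallback above is roughly the schoolbook identity: write a = a1*2^32 + a0 and b = b1*2^32 + b0, so that
a*b = a1*b1*2^64 + (a1*b0 + a0*b1)*2^32 + a0*b0.
The low 64 bits are just the wrapping product a*b plus lo; the high 64 bits are a1*b1 plus the high halves of t1 and t2, which accumulate the two cross terms and the carry out of a0*b0, plus the carry from adding lo (extracted with the unsafe bool-to-byte trick).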
const mask32 = 1<<32 - 1
// mul51 returns lo + hi * 2⁵¹ = a * b.
func mul51(a uint64, b uint32) (lo uint64, hi uint64) {
w0 := (a & mask32) * uint64(b)
t := (a>>32)*uint64(b) + w0>>32
w1 := t & mask32
w2 := t >> 32
mh := w2 + w1>>32
ml := a * uint64(b)
lo = ml & maskLow51Bits
hi = (mh << 13) | (ml >> 51)
return
}

File diff suppressed because it is too large

View File

@ -1,133 +0,0 @@
// Copyright 2019 Henry de Valence. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package scalar
import (
"bytes"
"math/big"
"testing"
"testing/quick"
)
// quickCheckConfig will make each quickcheck test run (1024 * -quickchecks)
// times. The default value of -quickchecks is 100.
var quickCheckConfig = &quick.Config{MaxCountScale: 1 << 10}
func TestFromBytesRoundTrip(t *testing.T) {
f1 := func(in, out [32]byte, sc Scalar) bool {
in[len(in)-1] &= (1 << 4) - 1 // Mask out top 4 bits for 252-bit numbers
if err := sc.FromCanonicalBytes(in[:]); err != nil {
return false
}
sc.Bytes(out[:0])
return bytes.Equal(in[:], out[:]) && scMinimal(sc[:])
}
if err := quick.Check(f1, nil); err != nil {
t.Errorf("failed bytes->scalar->bytes round-trip: %v", err)
}
f2 := func(sc1, sc2 Scalar, out [32]byte) bool {
sc1.Bytes(out[:0])
if err := sc2.FromCanonicalBytes(out[:]); err != nil {
return false
}
sc1.reduce()
sc2.reduce()
return sc1 == sc2
}
if err := quick.Check(f2, nil); err != nil {
t.Errorf("failed scalar->bytes->scalar round-trip: %v", err)
}
}
func TestFromUniformBytes(t *testing.T) {
mod, _ := new(big.Int).SetString("27742317777372353535851937790883648493", 10)
mod.Add(mod, new(big.Int).Lsh(big.NewInt(1), 252))
f := func(in [64]byte, sc Scalar) bool {
sc.FromUniformBytes(in[:])
if !scMinimal(sc[:]) {
return false
}
b := sc.Bytes(nil)
byteSwap(b) // convert to big endian for SetBytes
scBig := new(big.Int).SetBytes(b)
byteSwap(in[:]) // convert to big endian for SetBytes
inBig := new(big.Int).SetBytes(in[:])
return inBig.Mod(inBig, mod).Cmp(scBig) == 0
}
if err := quick.Check(f, nil); err != nil {
t.Error(err)
}
}
func byteSwap(b []byte) {
for i := range b[:len(b)/2] {
b[i], b[len(b)-i-1] = b[len(b)-i-1], b[i]
}
}
func TestMulDistributesOverAdd(t *testing.T) {
mulDistributesOverAdd := func(x, y, z Scalar) bool {
// Compute t1 = (x+y)*z
var t1 Scalar
t1.Add(&x, &y)
t1.Mul(&t1, &z)
// Compute t2 = x*z + y*z
var t2 Scalar
var t3 Scalar
t2.Mul(&x, &z)
t3.Mul(&y, &z)
t2.Add(&t2, &t3)
return t1.Equal(&t2) == 1 && scMinimal(t1[:]) && scMinimal(t2[:])
}
if err := quick.Check(mulDistributesOverAdd, quickCheckConfig); err != nil {
t.Error(err)
}
}
func TestNonAdjacentForm(t *testing.T) {
s := Scalar([32]byte{
0x1a, 0x0e, 0x97, 0x8a, 0x90, 0xf6, 0x62, 0x2d,
0x37, 0x47, 0x02, 0x3f, 0x8a, 0xd8, 0x26, 0x4d,
0xa7, 0x58, 0xaa, 0x1b, 0x88, 0xe0, 0x40, 0xd1,
0x58, 0x9e, 0x7b, 0x7f, 0x23, 0x76, 0xef, 0x09,
})
expectedNaf := [256]int8{
0, 13, 0, 0, 0, 0, 0, 0, 0, 7, 0, 0, 0, 0, 0, 0, -9, 0, 0, 0, 0, -11, 0, 0, 0, 0, 3, 0, 0, 0, 0, 1,
0, 0, 0, 0, 9, 0, 0, 0, 0, -5, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 11, 0, 0, 0, 0, 11, 0, 0, 0, 0, 0,
-9, 0, 0, 0, 0, 0, -3, 0, 0, 0, 0, 9, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 9, 0,
0, 0, 0, -15, 0, 0, 0, 0, -7, 0, 0, 0, 0, -9, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 13, 0, 0, 0, 0, 0, -3, 0,
0, 0, 0, -11, 0, 0, 0, 0, -7, 0, 0, 0, 0, -13, 0, 0, 0, 0, 11, 0, 0, 0, 0, -9, 0, 0, 0, 0, 0, 1, 0, 0,
0, 0, 0, -15, 0, 0, 0, 0, 1, 0, 0, 0, 0, 7, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 13, 0, 0, 0,
0, 0, 0, 11, 0, 0, 0, 0, 0, 15, 0, 0, 0, 0, 0, -9, 0, 0, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 7,
0, 0, 0, 0, 0, -15, 0, 0, 0, 0, 0, 15, 0, 0, 0, 0, 15, 0, 0, 0, 0, 15, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0,
}
sNaf := s.NonAdjacentForm(5)
for i := 0; i < 256; i++ {
if expectedNaf[i] != sNaf[i] {
t.Errorf("Wrong digit at position %d, got %d, expected %d", i, sNaf[i], expectedNaf[i])
}
}
}
func TestInvert(t *testing.T) {
invertWorks := func(x Scalar) bool {
var xInv, check Scalar
xInv.Inv(&x)
check.Mul(&x, &xInv)
return check.Equal(&scOne) == 1
}
if err := quick.Check(invertWorks, quickCheckConfig); err != nil {
t.Error(err)
}
}

View File

@ -17,28 +17,61 @@ import (
"encoding/base64"
"errors"
"github.com/gtank/ristretto255/internal/edwards25519"
"github.com/gtank/ristretto255/internal/radix51"
"github.com/gtank/ristretto255/internal/scalar"
"filippo.io/edwards25519"
"filippo.io/edwards25519/field"
)
// Constants from draft-hdevalence-cfrg-ristretto-01, Section 3.1.
// Constants from draft-hdevalence-cfrg-ristretto-01, Section 3.1. See
// TestConstants for their decimal values.
var (
sqrtM1 = fieldElementFromDecimal(
"19681161376707505956807079304988542015446066515923890162744021073123829784752")
sqrtADMinusOne = fieldElementFromDecimal(
"25063068953384623474111414158702152701244531502492656460079210482610430750235")
invSqrtAMinusD = fieldElementFromDecimal(
"54469307008909316920995813868745141605393597292927456921205312896311721017578")
oneMinusDSQ = fieldElementFromDecimal(
"1159843021668779879193775521855586647937357759715417654439879720876111806838")
dMinusOneSQ = fieldElementFromDecimal(
"40440834346308536858101042469323190826248399146238708352240133220865137265952")
d, _ = new(field.Element).SetBytes([]byte{
0xa3, 0x78, 0x59, 0x13, 0xca, 0x4d, 0xeb, 0x75,
0xab, 0xd8, 0x41, 0x41, 0x4d, 0x0a, 0x70, 0x00,
0x98, 0xe8, 0x79, 0x77, 0x79, 0x40, 0xc7, 0x8c,
0x73, 0xfe, 0x6f, 0x2b, 0xee, 0x6c, 0x03, 0x52,
})
sqrtM1, _ = new(field.Element).SetBytes([]byte{
0xb0, 0xa0, 0x0e, 0x4a, 0x27, 0x1b, 0xee, 0xc4,
0x78, 0xe4, 0x2f, 0xad, 0x06, 0x18, 0x43, 0x2f,
0xa7, 0xd7, 0xfb, 0x3d, 0x99, 0x00, 0x4d, 0x2b,
0x0b, 0xdf, 0xc1, 0x4f, 0x80, 0x24, 0x83, 0x2b,
})
sqrtADMinusOne, _ = new(field.Element).SetBytes([]byte{
0x1b, 0x2e, 0x7b, 0x49, 0xa0, 0xf6, 0x97, 0x7e,
0xbd, 0x54, 0x78, 0x1b, 0x0c, 0x8e, 0x9d, 0xaf,
0xfd, 0xd1, 0xf5, 0x31, 0xc9, 0xfc, 0x3c, 0x0f,
0xac, 0x48, 0x83, 0x2b, 0xbf, 0x31, 0x69, 0x37,
})
invSqrtAMinusD, _ = new(field.Element).SetBytes([]byte{
0xea, 0x40, 0x5d, 0x80, 0xaa, 0xfd, 0xc8, 0x99,
0xbe, 0x72, 0x41, 0x5a, 0x17, 0x16, 0x2f, 0x9d,
0x40, 0xd8, 0x01, 0xfe, 0x91, 0x7b, 0xc2, 0x16,
0xa2, 0xfc, 0xaf, 0xcf, 0x05, 0x89, 0x6c, 0x78,
})
oneMinusDSQ, _ = new(field.Element).SetBytes([]byte{
0x76, 0xc1, 0x5f, 0x94, 0xc1, 0x09, 0x7c, 0xe2,
0x0f, 0x35, 0x5e, 0xcd, 0x38, 0xa1, 0x81, 0x2c,
0xe4, 0xdf, 0x70, 0xbe, 0xdd, 0xab, 0x94, 0x99,
0xd7, 0xe0, 0xb3, 0xb2, 0xa8, 0x72, 0x90, 0x02,
})
dMinusOneSQ, _ = new(field.Element).SetBytes([]byte{
0x20, 0x4d, 0xed, 0x44, 0xaa, 0x5a, 0xad, 0x31,
0x99, 0x19, 0x1e, 0xb0, 0x2c, 0x4a, 0x9e, 0xd2,
0xeb, 0x4e, 0x9b, 0x52, 0x2f, 0xd3, 0xdc, 0x4c,
0x41, 0x22, 0x6c, 0xf6, 0x7a, 0xb3, 0x68, 0x59,
})
)
var (
zero = new(field.Element)
one = new(field.Element).One()
two = new(field.Element).Add(one, one)
minusOne = new(field.Element).Subtract(zero, one)
)
// Element is an element of the ristretto255 prime-order group.
type Element struct {
r edwards25519.ProjP3
r edwards25519.Point
}
// NewElement returns a new Element set to the identity value.
@ -51,14 +84,14 @@ func NewElement() *Element {
// NewIdentityElement returns a new Element set to the identity value.
func NewIdentityElement() *Element {
e := &Element{}
e.r.Zero()
e.r.Set(edwards25519.NewIdentityPoint())
return e
}
// NewGeneratorElement returns a new Element set to the canonical generator.
func NewGeneratorElement() *Element {
e := &Element{}
e.r.Set(&edwards25519.B)
e.r.Set(edwards25519.NewGeneratorPoint())
return e
}
@ -72,14 +105,17 @@ func (e *Element) Set(x *Element) *Element {
//
// Note that Elements must not be compared in any other way.
func (e *Element) Equal(ee *Element) int {
var f0, f1 radix51.FieldElement
X1, Y1, _, _ := e.r.ExtendedCoordinates()
X2, Y2, _, _ := ee.r.ExtendedCoordinates()
f0.Mul(&e.r.X, &ee.r.Y) // x1 * y2
f1.Mul(&e.r.Y, &ee.r.X) // y1 * x2
var f0, f1 field.Element
f0.Multiply(X1, Y2) // x1 * y2
f1.Multiply(Y1, X2) // y1 * x2
out := f0.Equal(&f1)
f0.Mul(&e.r.Y, &ee.r.Y) // y1 * y2
f1.Mul(&e.r.X, &ee.r.X) // x1 * x2
f0.Multiply(Y1, Y2) // y1 * y2
f1.Multiply(X1, X2) // x1 * x2
out = out | f0.Equal(&f1)
return out
@ -106,13 +142,13 @@ func (e *Element) SetUniformBytes(b []byte) (*Element, error) {
return nil, errors.New("ristretto255: SetUniformBytes input is not 64 bytes long")
}
f := &radix51.FieldElement{}
f := &field.Element{}
f.FromBytes(b[:32])
f.SetBytes(b[:32])
point1 := &Element{}
mapToPoint(&point1.r, f)
f.FromBytes(b[32:])
f.SetBytes(b[32:])
point2 := &Element{}
mapToPoint(&point2.r, f)
@ -120,32 +156,32 @@ func (e *Element) SetUniformBytes(b []byte) (*Element, error) {
}
// mapToPoint implements MAP from Section 3.2.4 of draft-hdevalence-cfrg-ristretto-00.
func mapToPoint(out *edwards25519.ProjP3, t *radix51.FieldElement) {
func mapToPoint(out *edwards25519.Point, t *field.Element) {
// r = SQRT_M1 * t^2
r := &radix51.FieldElement{}
r.Mul(sqrtM1, r.Square(t))
r := &field.Element{}
r.Multiply(sqrtM1, r.Square(t))
// u = (r + 1) * ONE_MINUS_D_SQ
u := &radix51.FieldElement{}
u.Mul(u.Add(r, radix51.One), oneMinusDSQ)
u := &field.Element{}
u.Multiply(u.Add(r, one), oneMinusDSQ)
// c = -1
c := &radix51.FieldElement{}
c.Set(radix51.MinusOne)
c := &field.Element{}
c.Set(minusOne)
// v = (c - r*D) * (r + D)
rPlusD := &radix51.FieldElement{}
rPlusD.Add(r, edwards25519.D)
v := &radix51.FieldElement{}
v.Mul(v.Sub(c, v.Mul(r, edwards25519.D)), rPlusD)
rPlusD := &field.Element{}
rPlusD.Add(r, d)
v := &field.Element{}
v.Multiply(v.Subtract(c, v.Multiply(r, d)), rPlusD)
// (was_square, s) = SQRT_RATIO_M1(u, v)
s := &radix51.FieldElement{}
wasSquare := feSqrtRatio(s, u, v)
s := &field.Element{}
_, wasSquare := s.SqrtRatio(u, v)
// s_prime = -CT_ABS(s*t)
sPrime := &radix51.FieldElement{}
sPrime.Neg(sPrime.Abs(sPrime.Mul(s, t)))
sPrime := &field.Element{}
sPrime.Negate(sPrime.Absolute(sPrime.Multiply(s, t)))
// s = CT_SELECT(s IF was_square ELSE s_prime)
s.Select(s, sPrime, wasSquare)
@ -153,31 +189,35 @@ func mapToPoint(out *edwards25519.ProjP3, t *radix51.FieldElement) {
c.Select(c, r, wasSquare)
// N = c * (r - 1) * D_MINUS_ONE_SQ - v
N := &radix51.FieldElement{}
N.Mul(c, N.Sub(r, radix51.One))
N.Sub(N.Mul(N, dMinusOneSQ), v)
N := &field.Element{}
N.Multiply(c, N.Subtract(r, one))
N.Subtract(N.Multiply(N, dMinusOneSQ), v)
s2 := &radix51.FieldElement{}
s2 := &field.Element{}
s2.Square(s)
// w0 = 2 * s * v
w0 := &radix51.FieldElement{}
w0.Add(w0, w0.Mul(s, v))
w0 := &field.Element{}
w0.Add(w0, w0.Multiply(s, v))
// w1 = N * SQRT_AD_MINUS_ONE
w1 := &radix51.FieldElement{}
w1.Mul(N, sqrtADMinusOne)
w1 := &field.Element{}
w1.Multiply(N, sqrtADMinusOne)
// w2 = 1 - s^2
w2 := &radix51.FieldElement{}
w2.Sub(radix51.One, s2)
w2 := &field.Element{}
w2.Subtract(one, s2)
// w3 = 1 + s^2
w3 := &radix51.FieldElement{}
w3.Add(radix51.One, s2)
w3 := &field.Element{}
w3.Add(one, s2)
// return (w0*w3, w2*w1, w1*w3, w0*w2)
out.X.Mul(w0, w3)
out.Y.Mul(w2, w1)
out.Z.Mul(w1, w3)
out.T.Mul(w0, w2)
var X, Y, Z, T field.Element
X.Multiply(w0, w3)
Y.Multiply(w2, w1)
Z.Multiply(w1, w3)
T.Multiply(w0, w2)
if _, err := out.SetExtendedCoordinates(&X, &Y, &Z, &T); err != nil {
panic("ristretto255: internal error: MAP generated invalid coordinates")
}
}
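For orientation, here is a minimal sketch of exercising this MAP through the public API (a hypothetical standalone program; the input string and the choice of SHA-512 are illustrative, and the import path is the module path from go.mod). SetUniformBytes applies MAP to each 32-byte half of the input and adds the two resulting points.
package main

import (
	"crypto/sha512"
	"fmt"

	"github.com/gtank/ristretto255"
)

func main() {
	// Hash-to-group: 64 uniform bytes -> one ristretto255 element.
	h := sha512.Sum512([]byte("example input"))
	e := ristretto255.NewElement()
	if _, err := e.SetUniformBytes(h[:]); err != nil {
		panic(err)
	}
	fmt.Printf("%x\n", e.Bytes())
}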
// Encode appends the 32 bytes canonical encoding of e to b
@ -185,75 +225,95 @@ func mapToPoint(out *edwards25519.ProjP3, t *radix51.FieldElement) {
//
// Deprecated: use Bytes. This API will be removed before v1.0.0.
func (e *Element) Encode(b []byte) []byte {
return e.bytes(b)
ret, out := sliceForAppend(b, 32)
e.bytes(out)
return ret
}
// sliceForAppend takes a slice and a requested number of bytes. It returns a
// slice with the contents of the given slice followed by that many bytes and a
// second slice that aliases into it and contains only the extra bytes. If the
// original slice has sufficient capacity then no allocation is performed.
func sliceForAppend(in []byte, n int) (head, tail []byte) {
if total := len(in) + n; cap(in) >= total {
head = in[:total]
} else {
head = make([]byte, total)
copy(head, in)
}
tail = head[len(in):]
return
}
// Bytes returns the 32 bytes canonical encoding of e.
func (e *Element) Bytes() []byte {
// Bytes is outlined to let the allocation happen on the stack of the caller.
b := make([]byte, 0, 32)
b := make([]byte, 32)
return e.bytes(b)
}
func (e *Element) bytes(b []byte) []byte {
tmp := &radix51.FieldElement{}
X, Y, Z, T := e.r.ExtendedCoordinates()
tmp := &field.Element{}
// u1 = (z0 + y0) * (z0 - y0)
u1 := &radix51.FieldElement{}
u1.Add(&e.r.Z, &e.r.Y).Mul(u1, tmp.Sub(&e.r.Z, &e.r.Y))
u1 := &field.Element{}
u1.Add(Z, Y).Multiply(u1, tmp.Subtract(Z, Y))
// u2 = x0 * y0
u2 := &radix51.FieldElement{}
u2.Mul(&e.r.X, &e.r.Y)
u2 := &field.Element{}
u2.Multiply(X, Y)
// Ignore was_square since this is always square
// (_, invsqrt) = SQRT_RATIO_M1(1, u1 * u2^2)
invSqrt := &radix51.FieldElement{}
feSqrtRatio(invSqrt, radix51.One, tmp.Square(u2).Mul(tmp, u1))
invSqrt := &field.Element{}
invSqrt.SqrtRatio(one, tmp.Square(u2).Multiply(tmp, u1))
// den1 = invsqrt * u1
// den2 = invsqrt * u2
den1, den2 := &radix51.FieldElement{}, &radix51.FieldElement{}
den1.Mul(invSqrt, u1)
den2.Mul(invSqrt, u2)
den1, den2 := &field.Element{}, &field.Element{}
den1.Multiply(invSqrt, u1)
den2.Multiply(invSqrt, u2)
// z_inv = den1 * den2 * t0
zInv := &radix51.FieldElement{}
zInv.Mul(den1, den2).Mul(zInv, &e.r.T)
zInv := &field.Element{}
zInv.Multiply(den1, den2).Multiply(zInv, T)
// ix0 = x0 * SQRT_M1
// iy0 = y0 * SQRT_M1
ix0, iy0 := &radix51.FieldElement{}, &radix51.FieldElement{}
ix0.Mul(&e.r.X, sqrtM1)
iy0.Mul(&e.r.Y, sqrtM1)
ix0, iy0 := &field.Element{}, &field.Element{}
ix0.Multiply(X, sqrtM1)
iy0.Multiply(Y, sqrtM1)
// enchanted_denominator = den1 * INVSQRT_A_MINUS_D
enchantedDenominator := &radix51.FieldElement{}
enchantedDenominator.Mul(den1, invSqrtAMinusD)
enchantedDenominator := &field.Element{}
enchantedDenominator.Multiply(den1, invSqrtAMinusD)
// rotate = IS_NEGATIVE(t0 * z_inv)
rotate := tmp.Mul(&e.r.T, zInv).IsNegative()
rotate := tmp.Multiply(T, zInv).IsNegative()
// x = CT_SELECT(iy0 IF rotate ELSE x0)
// y = CT_SELECT(ix0 IF rotate ELSE y0)
x, y := &radix51.FieldElement{}, &radix51.FieldElement{}
x.Select(iy0, &e.r.X, rotate)
y.Select(ix0, &e.r.Y, rotate)
x, y := &field.Element{}, &field.Element{}
x.Select(iy0, X, rotate)
y.Select(ix0, Y, rotate)
// z = z0
z := &e.r.Z
z := Z
// den_inv = CT_SELECT(enchanted_denominator IF rotate ELSE den2)
denInv := &radix51.FieldElement{}
denInv := &field.Element{}
denInv.Select(enchantedDenominator, den2, rotate)
// y = CT_NEG(y, IS_NEGATIVE(x * z_inv))
y.CondNeg(y, tmp.Mul(x, zInv).IsNegative())
isNegative := tmp.Multiply(x, zInv).IsNegative()
y.Select(tmp.Negate(y), y, isNegative)
// s = CT_ABS(den_inv * (z - y))
s := tmp.Sub(z, y).Mul(tmp, denInv).Abs(tmp)
s := tmp.Subtract(z, y).Multiply(tmp, denInv).Absolute(tmp)
// Return the canonical little-endian encoding of s.
return s.Bytes(b)
copy(b, s.Bytes())
return b
}
var errInvalidEncoding = errors.New("invalid Ristretto encoding")
var errInvalidEncoding = errors.New("ristretto255: invalid element encoding")
// Decode sets e to the decoded value of in. If in is not a 32 byte canonical
// encoding, Decode returns an error, and the receiver is unchanged.
@ -273,12 +333,11 @@ func (e *Element) SetCanonicalBytes(in []byte) (*Element, error) {
}
// First, interpret the string as an integer s in little-endian representation.
s := &radix51.FieldElement{}
s.FromBytes(in)
s := &field.Element{}
s.SetBytes(in)
// If the resulting value is >= p, decoding fails.
var buf [32]byte
if !bytes.Equal(s.Bytes(buf[:0]), in) {
if !bytes.Equal(s.Bytes(), in) {
return nil, errInvalidEncoding
}
@ -288,63 +347,65 @@ func (e *Element) SetCanonicalBytes(in []byte) (*Element, error) {
}
// ss = s^2
sSqr := &radix51.FieldElement{}
sSqr := &field.Element{}
sSqr.Square(s)
// u1 = 1 - ss
u1 := &radix51.FieldElement{}
u1.Sub(radix51.One, sSqr)
u1 := &field.Element{}
u1.Subtract(one, sSqr)
// u2 = 1 + ss
u2 := &radix51.FieldElement{}
u2.Add(radix51.One, sSqr)
u2 := &field.Element{}
u2.Add(one, sSqr)
// u2_sqr = u2^2
u2Sqr := &radix51.FieldElement{}
u2Sqr := &field.Element{}
u2Sqr.Square(u2)
// v = -(D * u1^2) - u2_sqr
v := &radix51.FieldElement{}
v.Square(u1).Mul(v, edwards25519.D).Neg(v).Sub(v, u2Sqr)
v := &field.Element{}
v.Square(u1).Multiply(v, d).Negate(v).Subtract(v, u2Sqr)
// (was_square, invsqrt) = SQRT_RATIO_M1(1, v * u2_sqr)
invSqrt, tmp := &radix51.FieldElement{}, &radix51.FieldElement{}
wasSquare := feSqrtRatio(invSqrt, radix51.One, tmp.Mul(v, u2Sqr))
invSqrt, tmp := &field.Element{}, &field.Element{}
_, wasSquare := invSqrt.SqrtRatio(one, tmp.Multiply(v, u2Sqr))
// den_x = invsqrt * u2
// den_y = invsqrt * den_x * v
denX, denY := &radix51.FieldElement{}, &radix51.FieldElement{}
denX.Mul(invSqrt, u2)
denY.Mul(invSqrt, denX).Mul(denY, v)
denX, denY := &field.Element{}, &field.Element{}
denX.Multiply(invSqrt, u2)
denY.Multiply(invSqrt, denX).Multiply(denY, v)
// x = CT_ABS(2 * s * den_x)
// y = u1 * den_y
// t = x * y
var out edwards25519.ProjP3
out.X.Mul(radix51.Two, s).Mul(&out.X, denX).Abs(&out.X)
out.Y.Mul(u1, denY)
out.Z.One()
out.T.Mul(&out.X, &out.Y)
var X, Y, Z, T field.Element
X.Multiply(two, s).Multiply(&X, denX).Absolute(&X)
Y.Multiply(u1, denY)
Z.One()
T.Multiply(&X, &Y)
// If was_square is FALSE, or IS_NEGATIVE(t) returns TRUE, or y = 0, decoding fails.
if wasSquare == 0 || out.T.IsNegative() == 1 || out.Y.Equal(radix51.Zero) == 1 {
if wasSquare == 0 || T.IsNegative() == 1 || Y.Equal(zero) == 1 {
return nil, errInvalidEncoding
}
// Otherwise, return the internal representation in extended coordinates (x, y, 1, t).
e.r.Set(&out)
if _, err := e.r.SetExtendedCoordinates(&X, &Y, &Z, &T); err != nil {
panic("ristretto255: internal error: DECODE generated invalid coordinates")
}
return e, nil
}
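A small round-trip sketch of the encoding and decoding paths just described (hypothetical standalone program, using only the public API shown in this diff):
package main

import (
	"bytes"
	"fmt"

	"github.com/gtank/ristretto255"
)

func main() {
	g := ristretto255.NewGeneratorElement()
	enc := g.Bytes() // canonical 32-byte encoding

	dec := ristretto255.NewElement()
	if _, err := dec.SetCanonicalBytes(enc); err != nil {
		panic(err) // only canonical encodings of valid elements decode
	}
	fmt.Println(dec.Equal(g) == 1, bytes.Equal(dec.Bytes(), enc)) // true true
}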
// ScalarBaseMult sets e = s * B, where B is the canonical generator, and returns e.
func (e *Element) ScalarBaseMult(s *Scalar) *Element {
e.r.BasepointMul(&s.s)
e.r.ScalarBaseMult(&s.s)
return e
}
// ScalarMult sets e = s * p, and returns e.
func (e *Element) ScalarMult(s *Scalar, p *Element) *Element {
e.r.ScalarMul(&s.s, &p.r)
e.r.ScalarMult(&s.s, &p.r)
return e
}
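As a usage sketch of the two scalar multiplication entry points, here is a hypothetical Diffie-Hellman-style exchange; the key generation details are illustrative only.
package main

import (
	"crypto/rand"
	"fmt"

	"github.com/gtank/ristretto255"
)

func randomScalar() *ristretto255.Scalar {
	var buf [64]byte
	if _, err := rand.Read(buf[:]); err != nil {
		panic(err)
	}
	s := ristretto255.NewScalar()
	if _, err := s.SetUniformBytes(buf[:]); err != nil {
		panic(err)
	}
	return s
}

func main() {
	a, b := randomScalar(), randomScalar()
	pubA := ristretto255.NewElement().ScalarBaseMult(a) // pubA = a*B
	pubB := ristretto255.NewElement().ScalarBaseMult(b) // pubB = b*B
	// Both sides compute the same shared element a*b*B.
	shared1 := ristretto255.NewElement().ScalarMult(a, pubB)
	shared2 := ristretto255.NewElement().ScalarMult(b, pubA)
	fmt.Println(shared1.Equal(shared2) == 1) // true
}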
@ -355,13 +416,13 @@ func (e *Element) MultiScalarMult(s []*Scalar, p []*Element) *Element {
if len(p) != len(s) {
panic("ristretto255: MultiScalarMult invoked with mismatched slice lengths")
}
points := make([]*edwards25519.ProjP3, len(p))
scalars := make([]scalar.Scalar, len(s))
points := make([]*edwards25519.Point, len(p))
scalars := make([]*edwards25519.Scalar, len(s))
for i := range s {
points[i] = &p[i].r
scalars[i] = s[i].s
scalars[i] = &s[i].s
}
e.r.MultiscalarMul(scalars, points)
e.r.MultiScalarMult(scalars, points)
return e
}
@ -370,15 +431,15 @@ func (e *Element) MultiScalarMult(s []*Scalar, p []*Element) *Element {
// Execution time depends on the inputs.
func (e *Element) VarTimeMultiScalarMult(s []*Scalar, p []*Element) *Element {
if len(p) != len(s) {
panic("ristretto255: MultiScalarMult invoked with mismatched slice lengths")
panic("ristretto255: VarTimeMultiScalarMult invoked with mismatched slice lengths")
}
points := make([]*edwards25519.ProjP3, len(p))
scalars := make([]scalar.Scalar, len(s))
points := make([]*edwards25519.Point, len(p))
scalars := make([]*edwards25519.Scalar, len(s))
for i := range s {
points[i] = &p[i].r
scalars[i] = s[i].s
scalars[i] = &s[i].s
}
e.r.VartimeMultiscalarMul(scalars, points)
e.r.VarTimeMultiScalarMult(scalars, points)
return e
}
@ -387,7 +448,7 @@ func (e *Element) VarTimeMultiScalarMult(s []*Scalar, p []*Element) *Element {
//
// Execution time depends on the inputs.
func (e *Element) VarTimeDoubleScalarBaseMult(a *Scalar, A *Element, b *Scalar) *Element {
e.r.VartimeDoubleBaseMul(&a.s, &A.r, &b.s)
e.r.VarTimeDoubleScalarBaseMult(&a.s, &A.r, &b.s)
return e
}
@ -399,13 +460,13 @@ func (e *Element) Add(p, q *Element) *Element {
// Subtract sets e = p - q, and returns e.
func (e *Element) Subtract(p, q *Element) *Element {
e.r.Sub(&p.r, &q.r)
e.r.Subtract(&p.r, &q.r)
return e
}
// Negate sets e = -p, and returns e.
func (e *Element) Negate(p *Element) *Element {
e.r.Neg(&p.r)
e.r.Negate(&p.r)
return e
}
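A quick consistency sketch for the group operations above (hypothetical, public API only):
package main

import (
	"fmt"

	"github.com/gtank/ristretto255"
)

func main() {
	p := ristretto255.NewGeneratorElement()
	q := ristretto255.NewElement().Add(p, p)      // 2*P
	r := ristretto255.NewElement().Subtract(q, p) // 2*P - P = P
	n := ristretto255.NewElement().Negate(p)      // -P
	z := ristretto255.NewElement().Add(p, n)      // P + (-P) = identity
	fmt.Println(r.Equal(p) == 1, z.Equal(ristretto255.NewIdentityElement()) == 1) // true true
}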

View File

@ -5,68 +5,12 @@ import (
"crypto/sha512"
"encoding/hex"
"encoding/json"
"math/big"
"testing"
"github.com/gtank/ristretto255/internal/radix51"
"filippo.io/edwards25519/field"
)
func assertFeEqual(value, expect *radix51.FieldElement) {
if value.Equal(expect) == 1 {
return
} else {
panic("failed equality assertion")
}
}
type sqrtRatioTest struct {
u, v *radix51.FieldElement
sqrt *radix51.FieldElement
choice int
negative int
}
func TestSqrtRatioM1(t *testing.T) {
// These tests can be found in curve25519-dalek's 'field.rs'
var (
zero, one = radix51.Zero, radix51.One
// Two is nonsquare in our field, 4 is square
two = new(radix51.FieldElement).Add(one, one)
four = new(radix51.FieldElement).Add(two, two)
// 2*i
twoTimesSqrtM1 = new(radix51.FieldElement).Mul(two, sqrtM1)
sqrt2i = fieldElementFromDecimal(
"38214883241950591754978413199355411911188925816896391856984770930832735035196")
invSqrt4 = fieldElementFromDecimal(
"28948022309329048855892746252171976963317496166410141009864396001978282409974")
)
// Check the construction of those magic numbers.
assertFeEqual(new(radix51.FieldElement).Mul(sqrt2i, sqrt2i), twoTimesSqrtM1)
assertFeEqual(new(radix51.FieldElement).Mul(new(radix51.FieldElement).Square(invSqrt4), four), one)
var tests = []sqrtRatioTest{
{u: zero, v: zero, sqrt: zero, choice: 1, negative: 0},
{u: zero, v: one, sqrt: zero, choice: 1, negative: 0},
{u: one, v: zero, sqrt: zero, choice: 0, negative: 0},
{u: two, v: one, sqrt: sqrt2i, choice: 0, negative: 0},
{u: four, v: one, sqrt: two, choice: 1, negative: 0},
{u: one, v: four, sqrt: invSqrt4, choice: 1, negative: 0},
}
for idx, tt := range tests {
sqrt := new(radix51.FieldElement)
choice := feSqrtRatio(sqrt, tt.u, tt.v)
if choice != tt.choice || sqrt.Equal(tt.sqrt) != 1 || sqrt.IsNegative() != tt.negative {
t.Errorf("Failed test %d", idx)
t.Logf("Got {u: %v, v: %v, sqrt: %v, choice: %d, neg: %d}", tt.u, tt.v, sqrt, choice, sqrt.IsNegative())
}
}
}
// The encoding of the canonical generator.
var compressedRistrettoBasepoint, _ = hex.DecodeString("e2f2ae0a6abc4e71a884a961c500515f58e30b6aa582dd8db6a65945e08d2d76")
@ -274,7 +218,7 @@ func TestMarshalScalar(t *testing.T) {
x := new(Scalar)
// generate an arbitrary scalar
xbytes := sha512.Sum512([]byte("Hello World"))
x.FromUniformBytes(xbytes[:])
x.SetUniformBytes(xbytes[:])
text, err := json.Marshal(x)
if err != nil {
t.Fatalf("Could not marshal json: %v", err)
@ -336,10 +280,10 @@ func TestScalarSet(t *testing.T) {
scOne := make([]byte, 32)
scOne[0] = 0x01
sc1, sc2 := NewScalar(), NewScalar().Zero()
sc1, sc2 := NewScalar(), NewScalar()
// sc1 <- 1
sc1.Decode(scOne)
sc1.SetCanonicalBytes(scOne)
// 1 != 0
if sc1.Equal(sc2) == 1 {
@ -362,3 +306,45 @@ func TestScalarSet(t *testing.T) {
t.Error("shouldn't have changed")
}
}
func TestConstants(t *testing.T) {
// From draft-hdevalence-cfrg-ristretto-01, Section 3.1.
t.Run("d", func(t *testing.T) {
testConstant(t, d,
"37095705934669439343138083508754565189542113879843219016388785533085940283555")
})
t.Run("sqrtM1", func(t *testing.T) {
testConstant(t, sqrtM1,
"19681161376707505956807079304988542015446066515923890162744021073123829784752")
})
t.Run("sqrtADMinusOne", func(t *testing.T) {
testConstant(t, sqrtADMinusOne,
"25063068953384623474111414158702152701244531502492656460079210482610430750235")
})
t.Run("invSqrtAMinusD", func(t *testing.T) {
testConstant(t, invSqrtAMinusD,
"54469307008909316920995813868745141605393597292927456921205312896311721017578")
})
t.Run("oneMinusDSQ", func(t *testing.T) {
testConstant(t, oneMinusDSQ,
"1159843021668779879193775521855586647937357759715417654439879720876111806838")
})
t.Run("dMinusOneSQ", func(t *testing.T) {
testConstant(t, dMinusOneSQ,
"40440834346308536858101042469323190826248399146238708352240133220865137265952")
})
}
func testConstant(t *testing.T, f *field.Element, decimal string) {
b, ok := new(big.Int).SetString(decimal, 10)
if !ok {
t.Fatal("invalid decimal")
}
buf := b.FillBytes(make([]byte, 32))
for i := 0; i < len(buf)/2; i++ {
buf[i], buf[len(buf)-i-1] = buf[len(buf)-i-1], buf[i]
}
if !bytes.Equal(f.Bytes(), buf) {
t.Errorf("expected %x", buf)
}
}

View File

@ -8,20 +8,22 @@ import (
"encoding/base64"
"errors"
"github.com/gtank/ristretto255/internal/scalar"
"filippo.io/edwards25519"
)
// A Scalar is an element of the ristretto255 scalar field, as specified in
// draft-hdevalence-cfrg-ristretto-01, Section 3.4. That is, an integer modulo
//
// l = 2^252 + 27742317777372353535851937790883648493
//
// The zero value is a valid zero element.
type Scalar struct {
s scalar.Scalar
s edwards25519.Scalar
}
// NewScalar returns a Scalar set to the value 0.
func NewScalar() *Scalar {
return (&Scalar{}).Zero()
return &Scalar{}
}
// Set sets the value of s to x and returns s.
@ -38,19 +40,19 @@ func (s *Scalar) Add(x, y *Scalar) *Scalar {
// Subtract sets s = x - y mod l and returns s.
func (s *Scalar) Subtract(x, y *Scalar) *Scalar {
s.s.Sub(&x.s, &y.s)
s.s.Subtract(&x.s, &y.s)
return s
}
// Negate sets s = -x mod l and returns s.
func (s *Scalar) Negate(x *Scalar) *Scalar {
s.s.Neg(&x.s)
s.s.Negate(&x.s)
return s
}
// Multiply sets s = x * y mod l and returns s.
func (s *Scalar) Multiply(x, y *Scalar) *Scalar {
s.s.Mul(&x.s, &y.s)
s.s.Multiply(&x.s, &y.s)
return s
}
@ -58,7 +60,7 @@ func (s *Scalar) Multiply(x, y *Scalar) *Scalar {
//
// If x is 0, the result is undefined.
func (s *Scalar) Invert(x *Scalar) *Scalar {
s.s.Inv(&x.s)
s.s.Invert(&x.s)
return s
}
@ -74,12 +76,12 @@ func (s *Scalar) FromUniformBytes(x []byte) *Scalar {
}
// SetUniformBytes sets s to a uniformly distributed value given 64 uniformly
// distributed random bytes.
// distributed random bytes. If x is not of the right length, SetUniformBytes
// returns nil and an error, and the receiver is unchanged.
func (s *Scalar) SetUniformBytes(x []byte) (*Scalar, error) {
if len(x) != 64 {
if _, err := s.s.SetUniformBytes(x); err != nil {
return nil, errors.New("ristretto255: SetUniformBytes input is not 64 bytes long")
}
s.s.FromUniformBytes(x)
return s, nil
}
@ -97,8 +99,8 @@ func (s *Scalar) Decode(x []byte) error {
// s. If x is not a canonical encoding of s, SetCanonicalBytes returns nil and
// an error and the receiver is unchanged.
func (s *Scalar) SetCanonicalBytes(x []byte) (*Scalar, error) {
if err := s.s.FromCanonicalBytes(x); err != nil {
return nil, err
if _, err := s.s.SetCanonicalBytes(x); err != nil {
return nil, errors.New("ristretto255: " + err.Error())
}
return s, nil
}
@ -107,14 +109,14 @@ func (s *Scalar) SetCanonicalBytes(x []byte) (*Scalar, error) {
//
// Deprecated: use Bytes. This API will be removed before v1.0.0.
func (s *Scalar) Encode(b []byte) []byte {
return s.s.Bytes(b)
ret, out := sliceForAppend(b, 32)
copy(out, s.s.Bytes())
return ret
}
// Bytes returns the 32 bytes little-endian encoding of s.
func (s *Scalar) Bytes() []byte {
// Bytes is small, so the allocation happens on the stack of the caller.
b := make([]byte, 0, 32)
return s.s.Bytes(b)
return s.s.Bytes()
}
// Equal returns 1 if v and u are equal, and 0 otherwise.
@ -124,7 +126,7 @@ func (s *Scalar) Equal(u *Scalar) int {
// Zero sets s = 0 and returns s.
func (s *Scalar) Zero() *Scalar {
s.s = scalar.Scalar{}
s.s = edwards25519.Scalar{}
return s
}
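Finally, a small sketch exercising the scalar API shown above (hypothetical standalone program; only methods present in this diff are used, and the scalar value 1 is built the same way as scOne in the tests):
package main

import (
	"bytes"
	"crypto/rand"
	"fmt"

	"github.com/gtank/ristretto255"
)

func main() {
	var seed [64]byte
	if _, err := rand.Read(seed[:]); err != nil {
		panic(err)
	}
	x := ristretto255.NewScalar()
	if _, err := x.SetUniformBytes(seed[:]); err != nil {
		panic(err)
	}

	// Bytes / SetCanonicalBytes round-trip: the canonical encoding decodes
	// back to the same scalar.
	enc := x.Bytes()
	y := ristretto255.NewScalar()
	if _, err := y.SetCanonicalBytes(enc); err != nil {
		panic(err)
	}

	// x * x^-1 = 1 mod l (x is nonzero with overwhelming probability).
	oneBytes := make([]byte, 32)
	oneBytes[0] = 1
	one, _ := ristretto255.NewScalar().SetCanonicalBytes(oneBytes)
	prod := ristretto255.NewScalar().Multiply(ristretto255.NewScalar().Invert(x), x)
	fmt.Println(x.Equal(y) == 1, bytes.Equal(enc, y.Bytes()), prod.Equal(one) == 1) // true true true
}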