ristretto255: make Encode an append API

This commit is contained in:
Filippo Valsorda 2019-04-19 13:39:18 -04:00 committed by Henry de Valence
parent 118379a17a
commit 50d9eb21ae
2 changed files with 14 additions and 14 deletions

View File

@ -133,17 +133,17 @@ func mapToPoint(out *edwards25519.ExtendedGroupElement, t *radix51.FieldElement)
out.T.Mul(w0, w2) out.T.Mul(w0, w2)
} }
// Encode encodes a Ristretto group element to its canonical bytestring. // Encode appends the canonical representation of e to b and returns the result.
func (ee *Element) Encode() []byte { func (e *Element) Encode(b []byte) []byte {
tmp := &radix51.FieldElement{} tmp := &radix51.FieldElement{}
// u1 = (z0 + y0) * (z0 - y0) // u1 = (z0 + y0) * (z0 - y0)
u1 := &radix51.FieldElement{} u1 := &radix51.FieldElement{}
u1.Add(&ee.r.Z, &ee.r.Y).Mul(u1, tmp.Sub(&ee.r.Z, &ee.r.Y)) u1.Add(&e.r.Z, &e.r.Y).Mul(u1, tmp.Sub(&e.r.Z, &e.r.Y))
// u2 = x0 * y0 // u2 = x0 * y0
u2 := &radix51.FieldElement{} u2 := &radix51.FieldElement{}
u2.Mul(&ee.r.X, &ee.r.Y) u2.Mul(&e.r.X, &e.r.Y)
// Ignore was_square since this is always square // Ignore was_square since this is always square
// (_, invsqrt) = SQRT_RATIO_M1(1, u1 * u2^2) // (_, invsqrt) = SQRT_RATIO_M1(1, u1 * u2^2)
@ -157,19 +157,19 @@ func (ee *Element) Encode() []byte {
zInv := &radix51.FieldElement{} zInv := &radix51.FieldElement{}
den1.Mul(invSqrt, u1) den1.Mul(invSqrt, u1)
den2.Mul(invSqrt, u2) den2.Mul(invSqrt, u2)
zInv.Mul(den1, den2).Mul(zInv, &ee.r.T) zInv.Mul(den1, den2).Mul(zInv, &e.r.T)
// ix0 = x0 * SQRT_M1 // ix0 = x0 * SQRT_M1
// iy0 = y0 * SQRT_M1 // iy0 = y0 * SQRT_M1
// enchanted_denominator = den1 * INVSQRT_A_MINUS_D // enchanted_denominator = den1 * INVSQRT_A_MINUS_D
ix0, iy0 := &radix51.FieldElement{}, &radix51.FieldElement{} ix0, iy0 := &radix51.FieldElement{}, &radix51.FieldElement{}
enchantedDenominator := &radix51.FieldElement{} enchantedDenominator := &radix51.FieldElement{}
ix0.Mul(&ee.r.X, sqrtM1) ix0.Mul(&e.r.X, sqrtM1)
iy0.Mul(&ee.r.Y, sqrtM1) iy0.Mul(&e.r.Y, sqrtM1)
enchantedDenominator.Mul(den1, invSqrtAMinusD) enchantedDenominator.Mul(den1, invSqrtAMinusD)
// rotate = IS_NEGATIVE(t0 * z_inv) // rotate = IS_NEGATIVE(t0 * z_inv)
rotate := tmp.Mul(&ee.r.T, zInv).IsNegative() rotate := tmp.Mul(&e.r.T, zInv).IsNegative()
// x = CT_SELECT(iy0 IF rotate ELSE x0) // x = CT_SELECT(iy0 IF rotate ELSE x0)
// y = CT_SELECT(ix0 IF rotate ELSE y0) // y = CT_SELECT(ix0 IF rotate ELSE y0)
@ -177,9 +177,9 @@ func (ee *Element) Encode() []byte {
// den_inv = CT_SELECT(enchanted_denominator IF rotate ELSE den2) // den_inv = CT_SELECT(enchanted_denominator IF rotate ELSE den2)
x, y := &radix51.FieldElement{}, &radix51.FieldElement{} x, y := &radix51.FieldElement{}, &radix51.FieldElement{}
denInv := &radix51.FieldElement{} denInv := &radix51.FieldElement{}
x.Select(iy0, &ee.r.X, rotate) x.Select(iy0, &e.r.X, rotate)
y.Select(ix0, &ee.r.Y, rotate) y.Select(ix0, &e.r.Y, rotate)
z := &ee.r.Z z := &e.r.Z
denInv.Select(enchantedDenominator, den2, rotate) denInv.Select(enchantedDenominator, den2, rotate)
// y = CT_NEG(y, IS_NEGATIVE(x * z_inv)) // y = CT_NEG(y, IS_NEGATIVE(x * z_inv))
@ -189,7 +189,7 @@ func (ee *Element) Encode() []byte {
s := tmp.Mul(denInv, tmp.Sub(z, y)).Abs(tmp) s := tmp.Mul(denInv, tmp.Sub(z, y)).Abs(tmp)
// Return the canonical little-endian encoding of s. // Return the canonical little-endian encoding of s.
return s.Bytes(nil) return s.Bytes(b)
} }
// Decode decodes the canonical bytestring encoding of an element into a // Decode decodes the canonical bytestring encoding of an element into a

View File

@ -89,12 +89,12 @@ func TestRistrettoEncoding(t *testing.T) {
t.Error("decode succeeded, but got wrong point") t.Error("decode succeeded, but got wrong point")
} }
roundtripBasepoint := decodedBasepoint.Encode() roundtripBasepoint := decodedBasepoint.Encode(nil)
if !bytes.Equal(compressedRistrettoBasepoint, roundtripBasepoint) { if !bytes.Equal(compressedRistrettoBasepoint, roundtripBasepoint) {
t.Error("decode<>encode roundtrip produced different results") t.Error("decode<>encode roundtrip produced different results")
} }
encodedBasepoint := ristrettoBasepoint.Encode() encodedBasepoint := ristrettoBasepoint.Encode(nil)
if !bytes.Equal(compressedRistrettoBasepoint, encodedBasepoint) { if !bytes.Equal(compressedRistrettoBasepoint, encodedBasepoint) {
t.Error("point encode produced different results") t.Error("point encode produced different results")
} }