Add radix51 FieldElement implementation

This contains both a pure Go and an amd64 assembly implementation of
operations over GF(2^255-19) using radix 2^51. This results in notable
speedups when using the assembly, but doesn't help much in pure Go -
most of the possible gains are lost to the lack of widening multiply for
64 bit integers.
This commit is contained in:
George Tankersley 2017-06-25 00:00:00 +00:00
parent 54cdb2aef0
commit ca2029ab22
11 changed files with 1108 additions and 0 deletions

View File

@ -15,6 +15,7 @@ import (
"testing"
"github.com/gtank/ed25519/internal/edwards25519"
"github.com/gtank/ed25519/internal/radix51"
)
// TEST MATH
@ -428,3 +429,40 @@ func BenchmarkScalarMult(b *testing.B) {
_, _ = ed.ScalarMult(Bx, By, k[:])
}
}
// A is a constant from the Montgomery form of curve25519.
var radix25A = edwards25519.FieldElement{
486662, 0, 0, 0, 0, 0, 0, 0, 0, 0,
}
var radix51A = radix51.FieldElement{
486662, 0, 0, 0, 0,
}
func BenchmarkFeMul25(b *testing.B) {
var h edwards25519.FieldElement
for i := 0; i < b.N; i++ {
edwards25519.FeMul(&h, &radix25A, &radix25A)
}
}
func BenchmarkFeMul51(b *testing.B) {
var h radix51.FieldElement
for i := 0; i < b.N; i++ {
radix51.FeMul(&h, &radix51A, &radix51A)
}
}
func BenchmarkFeSquare25(b *testing.B) {
var h edwards25519.FieldElement
for i := 0; i < b.N; i++ {
edwards25519.FeSquare(&h, &radix25A)
}
}
func BenchmarkFeSquare51(b *testing.B) {
var h radix51.FieldElement
for i := 0; i < b.N; i++ {
radix51.FeSquare(&h, &radix51A)
}
}

View File

@ -0,0 +1,9 @@
// Copyright 2017 George Tankersley. All rights reserved.
// Constants used in the implementation of GF(2^255-19) field arithmetic.
package radix51
const (
// The vaule 2^51-1, used in carry propagation
maskLow51Bits = uint64(1)<<51 - 1
)

300
internal/radix51/fe.go Normal file
View File

@ -0,0 +1,300 @@
// Copyright 2017 George Tankersley. All rights reserved.
// Field arithmetic in radix 2^51 representation. This code is a port of the
// public domain amd64-51-30k version of ed25519 from SUPERCOP.
package radix51
// FieldElement represents an element of the field GF(2^255-19). An element t
// represents the integer t[0] + t[1]*2^51 + t[2]*2^102 + t[3]*2^153 +
// t[4]*2^204.
type FieldElement [5]uint64
func (v *FieldElement) FeZero() {
v[0] = 0
v[1] = 0
v[2] = 0
v[3] = 0
v[4] = 0
}
func (v *FieldElement) FeOne() {
v[0] = 1
v[1] = 0
v[2] = 0
v[3] = 0
v[4] = 0
}
// SetInt sets the receiving FieldElement to the specified small integer.
func (v *FieldElement) SetInt(x uint64) {
v[0] = x
v[1] = 0
v[2] = 0
v[3] = 0
v[4] = 0
}
func FeReduce(t, v *FieldElement) {
// Copy v
*t = *v
// Lev v = v[0] + v[1]*2^51 + v[2]*2^102 + v[3]*2^153 + v[4]*2^204
// Reduce each limb below 2^51, propagating carries.
t[1] += t[0] >> 51
t[0] = t[0] & maskLow51Bits
t[2] += t[1] >> 51
t[1] = t[1] & maskLow51Bits
t[3] += t[2] >> 51
t[2] = t[2] & maskLow51Bits
t[4] += t[3] >> 51
t[3] = t[3] & maskLow51Bits
t[0] += (t[4] >> 51) * 19
t[4] = t[4] & maskLow51Bits
// We now hate a field element t < 2^255, but need t <= 2^255-19
// TODO Document why this works. It's the elaborate comment about r = h-pq etc etc.
// Get the carry bit
c := (t[0] + 19) >> 51
c = (t[1] + c) >> 51
c = (t[2] + c) >> 51
c = (t[3] + c) >> 51
c = (t[4] + c) >> 51
t[0] += 19 * c
t[1] += t[0] >> 51
t[0] = t[0] & maskLow51Bits
t[2] += t[1] >> 51
t[1] = t[1] & maskLow51Bits
t[3] += t[2] >> 51
t[2] = t[2] & maskLow51Bits
t[4] += t[3] >> 51
t[3] = t[3] & maskLow51Bits
// no additional carry
t[4] = t[4] & maskLow51Bits
}
// FeAdd sets out = a + b. Long sequences of additions without reduction that
// let coefficients grow larger than 54 bits would be a problem. Paper
// cautions: "do not have such sequences of additions".
func FeAdd(out, a, b *FieldElement) {
out[0] = a[0] + b[0]
out[1] = a[1] + b[1]
out[2] = a[2] + b[2]
out[3] = a[3] + b[3]
out[4] = a[4] + b[4]
}
// FeSub sets out = a - b
func FeSub(out, a, b *FieldElement) {
var t FieldElement
t = *b
// Reduce each limb below 2^51, propagating carries. Ensures that results
// fit within the limbs. This would not be required for reduced input.
t[1] += t[0] >> 51
t[0] = t[0] & maskLow51Bits
t[2] += t[1] >> 51
t[1] = t[1] & maskLow51Bits
t[3] += t[2] >> 51
t[2] = t[2] & maskLow51Bits
t[4] += t[3] >> 51
t[3] = t[3] & maskLow51Bits
t[0] += (t[4] >> 51) * 19
t[4] = t[4] & maskLow51Bits
// This is slightly more complicated. Because we use unsigned coefficients, we
// first add a multiple of p and then subtract.
out[0] = (a[0] + 0xFFFFFFFFFFFDA) - t[0]
out[1] = (a[1] + 0xFFFFFFFFFFFFE) - t[1]
out[2] = (a[2] + 0xFFFFFFFFFFFFE) - t[2]
out[3] = (a[3] + 0xFFFFFFFFFFFFE) - t[3]
out[4] = (a[4] + 0xFFFFFFFFFFFFE) - t[4]
}
// FeNeg sets out = -a
func FeNeg(out, a *FieldElement) {
var t FieldElement
t.SetInt(0)
FeSub(out, &t, a)
}
// FeInvert sets out = 1/z mod p by calculating z^(p-2), p-2 = 2^255 - 21.
func FeInvert(out, z *FieldElement) {
// Inversion is implemented as exponentiation with exponent p 2. It uses the
// same sequence of 255 squarings and 11 multiplications as [Curve25519].
var z2, z9, z11, z2_5_0, z2_10_0, z2_20_0, z2_50_0, z2_100_0, t FieldElement
FeSquare(&z2, z) // 2
FeSquare(&t, &z2) // 4
FeSquare(&t, &t) // 8
FeMul(&z9, &t, z) // 9
FeMul(&z11, &z9, &z2) // 11
FeSquare(&t, &z11) // 22
FeMul(&z2_5_0, &t, &z9) // 2^5 - 2^0 = 31
FeSquare(&t, &z2_5_0) // 2^6 - 2^1
for i := 0; i < 4; i++ {
FeSquare(&t, &t) // 2^10 - 2^5
}
FeMul(&z2_10_0, &t, &z2_5_0) // 2^10 - 2^0
FeSquare(&t, &z2_10_0) // 2^11 - 2^1
for i := 0; i < 9; i++ {
FeSquare(&t, &t) // 2^20 - 2^10
}
FeMul(&z2_20_0, &t, &z2_10_0) // 2^20 - 2^0
FeSquare(&t, &z2_20_0) // 2^21 - 2^1
for i := 0; i < 19; i++ {
FeSquare(&t, &t) // 2^40 - 2^20
}
FeMul(&t, &t, &z2_20_0) // 2^40 - 2^0
FeSquare(&t, &t) // 2^41 - 2^1
for i := 0; i < 9; i++ {
FeSquare(&t, &t) // 2^50 - 2^10
}
FeMul(&z2_50_0, &t, &z2_10_0) // 2^50 - 2^0
FeSquare(&t, &z2_50_0) // 2^51 - 2^1
for i := 0; i < 49; i++ {
FeSquare(&t, &t) // 2^100 - 2^50
}
FeMul(&z2_100_0, &t, &z2_50_0) // 2^100 - 2^0
FeSquare(&t, &z2_100_0) // 2^101 - 2^1
for i := 0; i < 99; i++ {
FeSquare(&t, &t) // 2^200 - 2^100
}
FeMul(&t, &t, &z2_100_0) // 2^200 - 2^0
FeSquare(&t, &t) // 2^201 - 2^1
for i := 0; i < 49; i++ {
FeSquare(&t, &t) // 2^250 - 2^50
}
FeMul(&t, &t, &z2_50_0) // 2^250 - 2^0
FeSquare(&t, &t) // 2^251 - 2^1
FeSquare(&t, &t) // 2^252 - 2^2
FeSquare(&t, &t) // 2^253 - 2^3
FeSquare(&t, &t) // 2^254 - 2^4
FeSquare(&t, &t) // 2^255 - 2^5
FeMul(out, &t, &z11) // 2^255 - 21
}
// FeCSwap implements a constant-time conditional swap. That is:
// (a, b) = (b, a) if c = 1
// (a, b) = (a, b) if c = 0
func FeCSwap(a, b *FieldElement, c uint64) {
var t FieldElement
t[0] ^= c & (a[0] ^ b[0])
t[1] ^= c & (a[1] ^ b[1])
t[2] ^= c & (a[2] ^ b[2])
t[3] ^= c & (a[3] ^ b[3])
t[4] ^= c & (a[4] ^ b[4])
a[0] ^= t[0]
a[1] ^= t[1]
a[2] ^= t[2]
a[3] ^= t[3]
a[4] ^= t[4]
b[0] ^= t[0]
b[1] ^= t[1]
b[2] ^= t[2]
b[3] ^= t[3]
b[4] ^= t[4]
}
func FeFromBytes(v *FieldElement, x *[32]byte) {
v[0] = uint64(x[0])
v[0] |= uint64(x[1]) << 8
v[0] |= uint64(x[2]) << 16
v[0] |= uint64(x[3]) << 24
v[0] |= uint64(x[4]) << 32
v[0] |= uint64(x[5]) << 40
v[0] |= uint64(x[6]&7) << 48
v[1] = uint64(x[6]) >> 3
v[1] |= uint64(x[7]) << 5
v[1] |= uint64(x[8]) << 13
v[1] |= uint64(x[9]) << 21
v[1] |= uint64(x[10]) << 29
v[1] |= uint64(x[11]) << 37
v[1] |= uint64(x[12]&63) << 45
v[2] = uint64(x[12]) >> 6
v[2] |= uint64(x[13]) << 2
v[2] |= uint64(x[14]) << 10
v[2] |= uint64(x[15]) << 18
v[2] |= uint64(x[16]) << 26
v[2] |= uint64(x[17]) << 34
v[2] |= uint64(x[18]) << 42
v[2] |= uint64(x[19]&1) << 50
v[3] = uint64(x[19]) >> 1
v[3] |= uint64(x[20]) << 7
v[3] |= uint64(x[21]) << 15
v[3] |= uint64(x[22]) << 23
v[3] |= uint64(x[23]) << 31
v[3] |= uint64(x[24]) << 39
v[3] |= uint64(x[25]&15) << 47
v[4] = uint64(x[25]) >> 4
v[4] |= uint64(x[26]) << 4
v[4] |= uint64(x[27]) << 12
v[4] |= uint64(x[28]) << 20
v[4] |= uint64(x[29]) << 28
v[4] |= uint64(x[30]) << 36
v[4] |= uint64(x[31]&127) << 44
}
func FeToBytes(r *[32]byte, v *FieldElement) {
var t FieldElement
FeReduce(&t, v)
r[0] = byte(t[0] & 0xff)
r[1] = byte((t[0] >> 8) & 0xff)
r[2] = byte((t[0] >> 16) & 0xff)
r[3] = byte((t[0] >> 24) & 0xff)
r[4] = byte((t[0] >> 32) & 0xff)
r[5] = byte((t[0] >> 40) & 0xff)
r[6] = byte((t[0] >> 48))
r[6] ^= byte((t[1] << 3) & 0xf8)
r[7] = byte((t[1] >> 5) & 0xff)
r[8] = byte((t[1] >> 13) & 0xff)
r[9] = byte((t[1] >> 21) & 0xff)
r[10] = byte((t[1] >> 29) & 0xff)
r[11] = byte((t[1] >> 37) & 0xff)
r[12] = byte((t[1] >> 45))
r[12] ^= byte((t[2] << 6) & 0xc0)
r[13] = byte((t[2] >> 2) & 0xff)
r[14] = byte((t[2] >> 10) & 0xff)
r[15] = byte((t[2] >> 18) & 0xff)
r[16] = byte((t[2] >> 26) & 0xff)
r[17] = byte((t[2] >> 34) & 0xff)
r[18] = byte((t[2] >> 42) & 0xff)
r[19] = byte((t[2] >> 50))
r[19] ^= byte((t[3] << 1) & 0xfe)
r[20] = byte((t[3] >> 7) & 0xff)
r[21] = byte((t[3] >> 15) & 0xff)
r[22] = byte((t[3] >> 23) & 0xff)
r[23] = byte((t[3] >> 31) & 0xff)
r[24] = byte((t[3] >> 39) & 0xff)
r[25] = byte((t[3] >> 47))
r[25] ^= byte((t[4] << 4) & 0xf0)
r[26] = byte((t[4] >> 4) & 0xff)
r[27] = byte((t[4] >> 12) & 0xff)
r[28] = byte((t[4] >> 20) & 0xff)
r[29] = byte((t[4] >> 28) & 0xff)
r[30] = byte((t[4] >> 36) & 0xff)
r[31] = byte((t[4] >> 44))
}

122
internal/radix51/fe_mul.go Normal file
View File

@ -0,0 +1,122 @@
// +build !amd64 noasm
package radix51
// FeMul sets out = a * b
func FeMul(out, x, y *FieldElement) {
var x0, x1, x2, x3, x4 uint64
var y0, y1, y2, y3, y4 uint64
x0 = x[0]
x1 = x[1]
x2 = x[2]
x3 = x[3]
x4 = x[4]
y0 = y[0]
y1 = y[1]
y2 = y[2]
y3 = y[3]
y4 = y[4]
// Reduction can be carried out simultaneously to multiplication. For
// example, we do not compute a coefficient r_5 . Whenever the result of a
// mul instruction belongs to r_5 , for example in the multiplication of
// x_3*y_2 , we multiply one of the inputs by 19 and add the result to r_0.
x1_19 := x1 * 19
x2_19 := x2 * 19
x3_19 := x3 * 19
x4_19 := x4 * 19
// calculate r0 = x0*y0 + 19*(x1*y4 + x2*y3 + x3*y2 + x4*y1)
r00, r01 := mul64x64(0, 0, x0, y0)
r00, r01 = mul64x64(r00, r01, x1_19, y4)
r00, r01 = mul64x64(r00, r01, x2_19, y3)
r00, r01 = mul64x64(r00, r01, x3_19, y2)
r00, r01 = mul64x64(r00, r01, x4_19, y1)
// calculate r1 = x0*y1 + x1*y0 + 19*(x2*y4 + x3*y3 + x4*y2)
r10, r11 := mul64x64(0, 0, x0, y1)
r10, r11 = mul64x64(r10, r11, x1, y0)
r10, r11 = mul64x64(r10, r11, x2_19, y4)
r10, r11 = mul64x64(r10, r11, x3_19, y3)
r10, r11 = mul64x64(r10, r11, x4_19, y2)
// calculate r2 = x0*y2 + x1*y1 + x2*y0 + 19*(x3*y4 + x4*y3)
r20, r21 := mul64x64(0, 0, x0, y2)
r20, r21 = mul64x64(r20, r21, x1, y1)
r20, r21 = mul64x64(r20, r21, x2, y0)
r20, r21 = mul64x64(r20, r21, x3_19, y4)
r20, r21 = mul64x64(r20, r21, x4_19, y3)
// calculate r3 = x0*y3 + x1*y2 + x2*y1 + x3*y0 + 19*x4*y4
r30, r31 := mul64x64(0, 0, x0, y3)
r30, r31 = mul64x64(r30, r31, x1, y2)
r30, r31 = mul64x64(r30, r31, x2, y1)
r30, r31 = mul64x64(r30, r31, x3, y0)
r30, r31 = mul64x64(r30, r31, x4_19, y4)
// calculate r4 = x0*y4 + x1*y3 + x2*y2 + x3*y1 + x4*y0
r40, r41 := mul64x64(0, 0, x0, y4)
r40, r41 = mul64x64(r40, r41, x1, y3)
r40, r41 = mul64x64(r40, r41, x2, y2)
r40, r41 = mul64x64(r40, r41, x3, y1)
r40, r41 = mul64x64(r40, r41, x4, y0)
// After the multiplication we need to reduce (carry) the 5 coefficients to
// obtain a result with coefficients that are at most slightly larger than
// 2^51 . Denote the two registers holding coefficient r_0 as r_00 and r_01
// with r_0 = 2^64*r_01 + r_00 . Similarly denote the two registers holding
// coefficient r_1 as r_10 and r_11 . We first shift r_01 left by 13, while
// shifting in the most significant bits of r_00 (shld instruction) and
// then compute the logical and of r_00 with 2^51 1. We do the same with
// r_10 and r_11 and add r_01 into r_10 after the logical and with 2^51
// 1. We proceed this way for coefficients r_2,...,r_4; register r_41 is
// multiplied by 19 before adding it to r_00 .
r01 = (r01 << 13) | (r00 >> 51)
r00 &= maskLow51Bits
r11 = (r11 << 13) | (r10 >> 51)
r10 &= maskLow51Bits
r10 += r01
r21 = (r21 << 13) | (r20 >> 51)
r20 &= maskLow51Bits
r20 += r11
r31 = (r31 << 13) | (r30 >> 51)
r30 &= maskLow51Bits
r30 += r21
r41 = (r41 << 13) | (r40 >> 51)
r40 &= maskLow51Bits
r40 += r31
r41 *= 19
r00 += r41
// Now all 5 coefficients fit into 64-bit registers but are still too large
// to be used as input to another multiplication. We therefore carry from
// r_0 to r_1 , from r_1 to r_2 , from r_2 to r_3 , from r_3 to r_4 , and
// finally from r_4 to r_0 . Each of these carries is done as one copy, one
// right shift by 51, one logical and with 2^51 1, and one addition.
r10 += r00 >> 51
r00 &= maskLow51Bits
r20 += r10 >> 51
r10 &= maskLow51Bits
r30 += r20 >> 51
r20 &= maskLow51Bits
r40 += r30 >> 51
r30 &= maskLow51Bits
r00 += (r40 >> 51) * 19
r40 &= maskLow51Bits
out[0] = r00
out[1] = r10
out[2] = r20
out[3] = r30
out[4] = r40
}

View File

@ -0,0 +1,6 @@
// +build amd64,!noasm
package radix51
// go:noescape
func FeMul(out, a, b *FieldElement)

View File

@ -0,0 +1,198 @@
// Based on assembly generated by PeachPy. Equivalent to the Go in fe_mul.go,
// which was originally based on the amd64-51-30k assembly in SUPERCOP.
// +build amd64,!noasm
// func FeMul(outp *uint64, xp *uint64, yp *uint64)
TEXT ·FeMul(SB),$0-24
MOVQ outp+0(FP), DI
MOVQ xp+8(FP), BX
MOVQ yp+16(FP), CX
// Calculate r0
MOVQ 0(BX), AX // rax <-- x0
MULQ 0(CX) // rdx, rax <-- x0*y0
MOVQ AX, SI // r00 = rax
MOVQ DX, BP // r01 = rdx
MOVQ 8(BX), DX // rdx <-- x1
IMUL3Q $19, DX, AX // rax <-- x1*19
MULQ 32(CX) // rdx, rax <-- x1_19*y4
ADDQ AX, SI // r00 += rax
ADCQ DX, BP // r01 += rdx
MOVQ 16(BX), DX // rdx <-- x2
IMUL3Q $19, DX, AX // rax <-- x2*19
MULQ 24(CX) // rdx, rax <-- x2_19*y3
ADDQ AX, SI // r00 += rax
ADCQ DX, BP // r01 += rdx
MOVQ 24(BX), DX // rdx <-- x3
IMUL3Q $19, DX, AX // rax <-- x3*19
MULQ 16(CX) // rdx, rax <-- x3_19 * y2
ADDQ AX, SI // r00 += rax
ADCQ DX, BP // r01 += rdx
MOVQ 32(BX), DX // rdx <-- x4
IMUL3Q $19, DX, AX // rax <-- x4*19
MULQ 8(CX) // rdx rax <-- x4_19*y1
ADDQ AX, SI // r00 += rax
ADCQ DX, BP // r01 += rdx
// Calculate r1
MOVQ 0(BX), AX
MULQ 8(CX)
MOVQ AX, R8 // r10
MOVQ DX, R9 // r11
MOVQ 8(BX), AX
MULQ 0(CX)
ADDQ AX, R8
ADCQ DX, R9
MOVQ 16(BX), DX
IMUL3Q $19, DX, AX
MULQ 32(CX)
ADDQ AX, R8
ADCQ DX, R9
MOVQ 24(BX), DX
IMUL3Q $19, DX, AX
MULQ 24(CX)
ADDQ AX, R8
ADCQ DX, R9
MOVQ 32(BX), DX
IMUL3Q $19, DX, AX
MULQ 16(CX)
ADDQ AX, R8
ADCQ DX, R9
// Calculate r2
MOVQ 0(BX), AX
MULQ 16(CX)
MOVQ AX, R10 // r20
MOVQ DX, R11 // r21
MOVQ 8(BX), AX
MULQ 8(CX)
ADDQ AX, R10
ADCQ DX, R11
MOVQ 16(BX), AX
MULQ 0(CX)
ADDQ AX, R10
ADCQ DX, R11
MOVQ 24(BX), DX
IMUL3Q $19, DX, AX
MULQ 32(CX)
ADDQ AX, R10
ADCQ DX, R11
MOVQ 32(BX), DX
IMUL3Q $19, DX, AX
MULQ 24(CX)
ADDQ AX, R10
ADCQ DX, R11
// Calculate r3
MOVQ 0(BX), AX
MULQ 24(CX)
MOVQ AX, R12 // r30
MOVQ DX, R13 // r31
MOVQ 8(BX), AX
MULQ 16(CX)
ADDQ AX, R12
ADCQ DX, R13
MOVQ 16(BX), AX
MULQ 8(CX)
ADDQ AX, R12
ADCQ DX, R13
MOVQ 24(BX), AX
MULQ 0(CX)
ADDQ AX, R12
ADCQ DX, R13
MOVQ 32(BX), DX
IMUL3Q $19, DX, AX
MULQ 32(CX)
ADDQ AX, R12
ADCQ DX, R13
// Calculate r4
MOVQ 0(BX), AX
MULQ 32(CX)
MOVQ AX, R14 // r40
MOVQ DX, R15 // r41
MOVQ 8(BX), AX
MULQ 24(CX)
ADDQ AX, R14
ADCQ DX, R15
MOVQ 16(BX), AX
MULQ 16(CX)
ADDQ AX, R14
ADCQ DX, R15
MOVQ 24(BX), AX
MULQ 8(CX)
ADDQ AX, R14
ADCQ DX, R15
MOVQ 32(BX), AX
MULQ 0(CX)
ADDQ AX, R14
ADCQ DX, R15
MOVQ $2251799813685247, AX // (1<<51) - 1
SHLQ $13, SI, BP // r01 = shld with r00
ANDQ AX, SI // r00 &= mask51
SHLQ $13, R8, R9 // r11 = shld with r10
ANDQ AX, R8 // r10 &= mask51
ADDQ BP, R8 // r10 += r01
SHLQ $13, R10, R11 // r21 = shld with r20
ANDQ AX, R10 // r20 &= mask51
ADDQ R9, R10 // r20 += r11
SHLQ $13, R12, R13 // r31 = shld with r30
ANDQ AX, R12 // r30 &= mask51
ADDQ R11, R12 // r30 += r21
SHLQ $13, R14, R15 // r41 = shld with r40
ANDQ AX, R14 // r40 &= mask51
ADDQ R13, R14 // r40 += r31
IMUL3Q $19, R15, R15 // r41 = r41*19
ADDQ R15, SI // r00 += r41
MOVQ SI, DX // rdx <-- r00
SHRQ $51, DX // rdx <-- r00 >> 51
ADDQ DX, R8 // r10 += r00 >> 51
MOVQ R8, DX // rdx <-- r10
SHRQ $51, DX // rdx <-- r10 >> 51
ANDQ AX, SI // r00 &= mask51
ADDQ DX, R10 // r20 += r10 >> 51
MOVQ R10, DX // rdx <-- r20
SHRQ $51, DX // rdx <-- r20 >> 51
ANDQ AX, R8 // r10 &= mask51
ADDQ DX, R12 // r30 += r20 >> 51
MOVQ R12, DX // rdx <-- r30
SHRQ $51, DX // rdx <-- r30 >> 51
ANDQ AX, R10 // r20 &= mask51
ADDQ DX, R14 // r40 += r30 >> 51
MOVQ R14, DX // rdx <-- r40
SHRQ $51, DX // rdx <-- r40 >> 51
ANDQ AX, R12 // r30 &= mask51
IMUL3Q $19, DX, DX // rdx <-- (r40 >> 51) * 19
ADDQ DX, SI // r00 += (r40 >> 51) *19
ANDQ AX, R14 // r40 &= mask51
MOVQ SI, 0(DI)
MOVQ R8, 8(DI)
MOVQ R10, 16(DI)
MOVQ R12, 24(DI)
MOVQ R14, 32(DI)
RET

View File

@ -0,0 +1,94 @@
// +build !amd64 noasm
package radix51
// FeSquare sets out = x*x
func FeSquare(out, x *FieldElement) {
// Squaring needs only 15 mul instructions. Some inputs are multiplied by 2;
// this is combined with multiplication by 19 where possible. The coefficient
// reduction after squaring is the same as for multiplication.
var x0, x1, x2, x3, x4 uint64
x0 = x[0]
x1 = x[1]
x2 = x[2]
x3 = x[3]
x4 = x[4]
x0_2 := x0 << 1
x1_2 := x1 << 1
x1_38 := x1 * 38
x2_38 := x2 * 38
x3_38 := x3 * 38
x3_19 := x3 * 19
x4_19 := x4 * 19
// r0 = x0*x0 + x1*38*x4 + x2*38*x3
r00, r01 := mul64x64(0, 0, x0, x0)
r00, r01 = mul64x64(r00, r01, x1_38, x4)
r00, r01 = mul64x64(r00, r01, x2_38, x3)
// r1 = x0*2*x1 + x2*38*x4 + x3*19*x3
r10, r11 := mul64x64(0, 0, x0_2, x1)
r10, r11 = mul64x64(r10, r11, x2_38, x4)
r10, r11 = mul64x64(r10, r11, x3_19, x3)
// r2 = x0*2*x2 + x1*x1 + x3*38*x4
r20, r21 := mul64x64(0, 0, x0_2, x2)
r20, r21 = mul64x64(r20, r21, x1, x1)
r20, r21 = mul64x64(r20, r21, x3_38, x4)
// r3 = x0*2*x3 + x1*2*x2 + x4*19*x4
r30, r31 := mul64x64(0, 0, x0_2, x3)
r30, r31 = mul64x64(r30, r31, x1_2, x2)
r30, r31 = mul64x64(r30, r31, x4_19, x4)
// r4 = x0*2*x4 + x1*2*x3 + x2*x2
r40, r41 := mul64x64(0, 0, x0_2, x4)
r40, r41 = mul64x64(r40, r41, x1_2, x3)
r40, r41 = mul64x64(r40, r41, x2, x2)
// Same reduction
r01 = (r01 << 13) | (r00 >> 51)
r00 &= maskLow51Bits
r11 = (r11 << 13) | (r10 >> 51)
r10 &= maskLow51Bits
r10 += r01
r21 = (r21 << 13) | (r20 >> 51)
r20 &= maskLow51Bits
r20 += r11
r31 = (r31 << 13) | (r30 >> 51)
r30 &= maskLow51Bits
r30 += r21
r41 = (r41 << 13) | (r40 >> 51)
r40 &= maskLow51Bits
r40 += r31
r41 *= 19
r00 += r41
r10 += r00 >> 51
r00 &= maskLow51Bits
r20 += r10 >> 51
r10 &= maskLow51Bits
r30 += r20 >> 51
r20 &= maskLow51Bits
r40 += r30 >> 51
r30 &= maskLow51Bits
r00 += (r40 >> 51) * 19
r40 &= maskLow51Bits
out[0] = r00
out[1] = r10
out[2] = r20
out[3] = r30
out[4] = r40
}

View File

@ -0,0 +1,6 @@
// +build amd64,!noasm
package radix51
// go:noescape
func FeSquare(out, x *FieldElement)

View File

@ -0,0 +1,146 @@
// +build amd64,!noasm
// func FeSquare(outp *uint64, xp *uint64)
TEXT ·FeSquare(SB),4,$0-16
MOVQ outp+0(FP), DI
MOVQ xp+8(FP), SI
// r0 = x0*x0 + x1*38*x4 + x2*38*x3
MOVQ 0(SI), AX
MULQ 0(SI)
MOVQ AX, CX // r00
MOVQ DX, R8 // r01
MOVQ 8(SI), DX
IMUL3Q $38, DX, AX
MULQ 32(SI)
ADDQ AX, CX
ADCQ DX, R8
MOVQ 16(SI), DX
IMUL3Q $38, DX, AX
MULQ 24(SI)
ADDQ AX, CX
ADCQ DX, R8
// r1 = x0*2*x1 + x2*38*x4 + x3*19*x3
MOVQ 0(SI), AX
SHLQ $1, AX
MULQ 8(SI)
MOVQ AX, R9 // r10
MOVQ DX, R10 // r11
MOVQ 16(SI), DX
IMUL3Q $38, DX, AX
MULQ 32(SI)
ADDQ AX, R9
ADCQ DX, R10
MOVQ 24(SI), DX
IMUL3Q $19, DX, AX
MULQ 24(SI)
ADDQ AX, R9
ADCQ DX, R10
// r2 = x0*2*x2 + x1*x1 + x3*38*x4
MOVQ 0(SI), AX
SHLQ $1, AX
MULQ 16(SI)
MOVQ AX, R11 // r20
MOVQ DX, R12 // r21
MOVQ 8(SI), AX
MULQ 8(SI)
ADDQ AX, R11
ADCQ DX, R12
MOVQ 24(SI), DX
IMUL3Q $38, DX, AX
MULQ 32(SI)
ADDQ AX, R11
ADCQ DX, R12
// r3 = x0*2*x3 + x1*2*x2 + x4*19*x4
MOVQ 0(SI), AX
SHLQ $1, AX
MULQ 24(SI)
MOVQ AX, R13 // r30
MOVQ DX, R14 // r31
MOVQ 8(SI), AX
SHLQ $1, AX
MULQ 16(SI)
ADDQ AX, R13
ADCQ DX, R14
MOVQ 32(SI), DX
IMUL3Q $19, DX, AX
MULQ 32(SI)
ADDQ AX, R13
ADCQ DX, R14
// r4 = x0*2*x4 + x1*2*x3 + x2*x2
MOVQ 0(SI), AX
SHLQ $1, AX
MULQ 32(SI)
MOVQ AX, R15 // r40
MOVQ DX, BX // r41
MOVQ 8(SI), AX
SHLQ $1, AX
MULQ 24(SI)
ADDQ AX, R15
ADCQ DX, BX
MOVQ 16(SI), AX
MULQ 16(SI)
ADDQ AX, R15
ADCQ DX, BX
// Reduce
MOVQ $2251799813685247, AX // (1<<51) - 1
SHLQ $13, CX, R8 // r01 = shld with r00
ANDQ AX, CX // r00 &= mask51
SHLQ $13, R9, R10 // r11 = shld with r10
ANDQ AX, R9 // r10 &= mask51
ADDQ R8, R9 // r10 += r01
SHLQ $13, R11, R12 // r21 = shld with r20
ANDQ AX, R11 // r20 &= mask51
ADDQ R10, R11 // r20 += r11
SHLQ $13, R13, R14 // r31 = shld with r30
ANDQ AX, R13 // r30 &= mask51
ADDQ R12, R13 // r30 += r21
SHLQ $13, R15, BX // r41 = shld with r40
ANDQ AX, R15 // r40 &= mask51
ADDQ R14, R15 // r40 += r31
IMUL3Q $19, BX, DX // r41 = r41*19
ADDQ DX, CX // r00 += r41
MOVQ CX, DX // rdx <-- r00
SHRQ $51, DX // rdx <-- r00 >> 51
ADDQ DX, R9 // r10 += r00 >> 51
MOVQ R9, DX // rdx <-- r10
SHRQ $51, DX // rdx <-- r10 >> 51
ANDQ AX, CX // r00 &= mask51
ADDQ DX, R11 // r20 += r10 >> 51
MOVQ R11, DX // rdx <-- r20
SHRQ $51, DX // rdx <-- r20 >> 51
ANDQ AX, R9 // r10 &= mask51
ADDQ DX, R13 // r30 += r20 >> 51
MOVQ R13, DX // rdx <-- r30
SHRQ $51, DX // rdx <-- r30 >> 51
ANDQ AX, R11 // r20 &= mask51
ADDQ DX, R15 // r40 += r30 >> 51
MOVQ R15, DX // rdx <-- r40
SHRQ $51, DX // rdx <-- r40 >> 51
ANDQ AX, R13 // r30 &= mask51
IMUL3Q $19, DX, DX // rdx <-- (r40 >> 51) * 19
ADDQ DX, CX // r00 += (r40 >> 51) *19
ANDQ AX, R15 // r40 &= mask51
MOVQ CX, 0(DI)
MOVQ R9, 8(DI)
MOVQ R11, 16(DI)
MOVQ R13, 24(DI)
MOVQ R15, 32(DI)
RET

175
internal/radix51/fe_test.go Normal file
View File

@ -0,0 +1,175 @@
package radix51
import (
"bytes"
"crypto/rand"
"io"
"testing"
"unsafe"
)
func TestMul64to128(t *testing.T) {
a := uint64(5)
b := uint64(5)
r0, r1 := mul64x64(0, 0, a, b)
if r0 != 0x19 || r1 != 0 {
t.Errorf("lo-range wide mult failed, got %d + %d*(2**64)", r0, r1)
}
a = uint64(18014398509481983) // 2^54 - 1
b = uint64(18014398509481983) // 2^54 - 1
r0, r1 = mul64x64(0, 0, a, b)
if r0 != 0xff80000000000001 || r1 != 0xfffffffffff {
t.Errorf("hi-range wide mult failed, got %d + %d*(2**64)", r0, r1)
}
a = uint64(1125899906842661)
b = uint64(2097155)
r0, r1 = mul64x64(0, 0, a, b)
r0, r1 = mul64x64(r0, r1, a, b)
r0, r1 = mul64x64(r0, r1, a, b)
r0, r1 = mul64x64(r0, r1, a, b)
r0, r1 = mul64x64(r0, r1, a, b)
if r0 != 16888498990613035 || r1 != 640 {
t.Errorf("wrong answer: %d + %d*(2**64)", r0, r1)
}
}
func BenchmarkWideMultInline(t *testing.B) {
var r0, r1, ol, oh uint64
a := uint64(18014398509481983) // 2^54 - 1
b := uint64(18014398509481983) // 2^54 - 1
for i := 0; i < t.N; i++ {
t1 := (a>>32)*(b&0xFFFFFFFF) + ((a & 0xFFFFFFFF) * (b & 0xFFFFFFFF) >> 32)
t2 := (a&0xFFFFFFFF)*(b>>32) + (t1 & 0xFFFFFFFF)
ol = (a * b) + r0
cmp := ol < r0
oh = r1 + (a>>32)*(b>>32) + t1>>32 + t2>>32 + uint64(*(*byte)(unsafe.Pointer(&cmp)))
r1 = oh
r0 = ol
}
}
func BenchmarkWideMultCall(t *testing.B) {
var r0, r1 uint64
a := uint64(18014398509481983)
b := uint64(18014398509481983)
for i := 0; i < t.N; i++ {
r0, r1 = mul64x64(r0, r1, a, b)
}
}
func TestFeFromBytesRoundTrip(t *testing.T) {
var in, out [32]byte
var fe, r FieldElement
in = [32]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32}
FeFromBytes(&fe, &in)
FeToBytes(&out, &fe)
if !bytes.Equal(in[:], out[:]) {
t.Error("Bytes<>FE doesn't roundtrip")
}
// Random field element
fe[0] = 0x4e645be9215a2
fe[1] = 0x4e9654922df12
fe[2] = 0x5829e468b0205
fe[3] = 0x5e8fca9e0881c
fe[4] = 0x5c490f087d796
FeToBytes(&out, &fe)
FeFromBytes(&r, &out)
for i := 0; i < len(fe); i++ {
if r[i] != fe[i] {
t.Error("FE<>Bytes doesn't roundtrip")
}
}
}
// Tests self-consistency between FeMul and FeSquare.
func TestSanity(t *testing.T) {
var x FieldElement
var x2, x2sq FieldElement
// var x2Go, x2sqGo FieldElement
x = [5]uint64{1, 1, 1, 1, 1}
FeMul(&x2, &x, &x)
// FeMulGo(&x2Go, &x, &x)
FeSquare(&x2sq, &x)
// FeSquareGo(&x2sqGo, &x)
// if !vartimeEqual(x2, x2Go) || !vartimeEqual(x2sq, x2sqGo) || !vartimeEqual(x2, x2sq) {
// t.Fatalf("all ones failed\nmul.s: %d\nmul.g: %d\nsqr.s: %d\nsqr.g: %d\n", x2, x2Go, x2sq, x2sqGo)
// }
if !vartimeEqual(x2, x2sq) {
t.Fatalf("all ones failed\nmul: %x\nsqr: %x\n", x2, x2sq)
}
var bytes [32]byte
_, err := io.ReadFull(rand.Reader, bytes[:])
if err != nil {
t.Fatal(err)
}
FeFromBytes(&x, &bytes)
FeMul(&x2, &x, &x)
// FeMulGo(&x2Go, &x, &x)
FeSquare(&x2sq, &x)
// FeSquareGo(&x2sqGo, &x)
// if !vartimeEqual(x2, x2Go) || !vartimeEqual(x2sq, x2sqGo) || !vartimeEqual(x2, x2sq) {
// t.Fatalf("random field element failed\nfe: %x\n\nmul.s: %x\nmul.g: %x\nsqr.s: %x\nsqr.g: %x\n", x, x2, x2Go, x2sq, x2sqGo)
// }
if !vartimeEqual(x2, x2sq) {
t.Fatalf("all ones failed\nmul: %x\nsqr: %x\n", x2, x2sq)
}
}
func vartimeEqual(x, y FieldElement) bool {
for i := 0; i < 5; i++ {
if x[i] != y[i] {
return false
}
}
return true
}
func TestFeInvert(t *testing.T) {
var x FieldElement = [5]uint64{1, 1, 1, 1, 1}
var one FieldElement = [5]uint64{1, 0, 0, 0, 0}
var xinv, r FieldElement
FeInvert(&xinv, &x)
FeMul(&r, &x, &xinv)
FeReduce(&r, &r)
if !vartimeEqual(one, r) {
t.Errorf("inversion identity failed, got: %x", r)
}
var bytes [32]byte
_, err := io.ReadFull(rand.Reader, bytes[:])
if err != nil {
t.Fatal(err)
}
FeFromBytes(&x, &bytes)
FeInvert(&xinv, &x)
FeMul(&r, &x, &xinv)
FeReduce(&r, &r)
if !vartimeEqual(one, r) {
t.Errorf("random inversion identity failed, got: %x for field element %x", r, x)
}
}

14
internal/radix51/mul.go Normal file
View File

@ -0,0 +1,14 @@
package radix51
import "unsafe"
// mul64x64 multiples two 64-bit numbers and adds them to two accumulators.
// This function is written to ensure it inlines. I am so sorry.
func mul64x64(lo, hi, a, b uint64) (ol uint64, oh uint64) {
t1 := (a>>32)*(b&0xFFFFFFFF) + ((a & 0xFFFFFFFF) * (b & 0xFFFFFFFF) >> 32)
t2 := (a&0xFFFFFFFF)*(b>>32) + (t1 & 0xFFFFFFFF)
ol = (a * b) + lo
cmp := ol < lo
oh = hi + (a>>32)*(b>>32) + t1>>32 + t2>>32 + uint64(*(*byte)(unsafe.Pointer(&cmp)))
return
}