diff --git a/ed25519_test.go b/ed25519_test.go
index 217ff75..0702e5b 100644
--- a/ed25519_test.go
+++ b/ed25519_test.go
@@ -15,6 +15,7 @@ import (
 	"testing"
 
 	"github.com/gtank/ed25519/internal/edwards25519"
+	"github.com/gtank/ed25519/internal/radix51"
 )
 
 // TEST MATH
@@ -428,3 +429,40 @@ func BenchmarkScalarMult(b *testing.B) {
 		_, _ = ed.ScalarMult(Bx, By, k[:])
 	}
 }
+
+// A = 486662 is a constant from the Montgomery form of curve25519.
+var radix25A = edwards25519.FieldElement{
+	486662, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+}
+
+var radix51A = radix51.FieldElement{
+	486662, 0, 0, 0, 0,
+}
+
+func BenchmarkFeMul25(b *testing.B) {
+	var h edwards25519.FieldElement
+	for i := 0; i < b.N; i++ {
+		edwards25519.FeMul(&h, &radix25A, &radix25A)
+	}
+}
+
+func BenchmarkFeMul51(b *testing.B) {
+	var h radix51.FieldElement
+	for i := 0; i < b.N; i++ {
+		radix51.FeMul(&h, &radix51A, &radix51A)
+	}
+}
+
+func BenchmarkFeSquare25(b *testing.B) {
+	var h edwards25519.FieldElement
+	for i := 0; i < b.N; i++ {
+		edwards25519.FeSquare(&h, &radix25A)
+	}
+}
+
+func BenchmarkFeSquare51(b *testing.B) {
+	var h radix51.FieldElement
+	for i := 0; i < b.N; i++ {
+		radix51.FeSquare(&h, &radix51A)
+	}
+}
diff --git a/internal/radix51/const.go b/internal/radix51/const.go
new file mode 100644
index 0000000..9b3f6b2
--- /dev/null
+++ b/internal/radix51/const.go
@@ -0,0 +1,9 @@
+// Copyright 2017 George Tankersley. All rights reserved.
+
+// Constants used in the implementation of GF(2^255-19) field arithmetic.
+package radix51
+
+const (
+	// The value 2^51 - 1, used in carry propagation.
+	maskLow51Bits = uint64(1)<<51 - 1
+)
diff --git a/internal/radix51/fe.go b/internal/radix51/fe.go
new file mode 100644
index 0000000..09ebc86
--- /dev/null
+++ b/internal/radix51/fe.go
@@ -0,0 +1,300 @@
+// Copyright 2017 George Tankersley. All rights reserved.
+
+// Field arithmetic in radix 2^51 representation. This code is a port of the
+// public domain amd64-51-30k version of ed25519 from SUPERCOP.
+package radix51
+
+// FieldElement represents an element of the field GF(2^255-19). An element t
+// represents the integer t[0] + t[1]*2^51 + t[2]*2^102 + t[3]*2^153 +
+// t[4]*2^204.
+type FieldElement [5]uint64
+
+// FeZero sets v = 0.
+func (v *FieldElement) FeZero() {
+	v[0] = 0
+	v[1] = 0
+	v[2] = 0
+	v[3] = 0
+	v[4] = 0
+}
+
+// FeOne sets v = 1.
+func (v *FieldElement) FeOne() {
+	v[0] = 1
+	v[1] = 0
+	v[2] = 0
+	v[3] = 0
+	v[4] = 0
+}
+
+// SetInt sets the receiving FieldElement to the specified small integer.
+func (v *FieldElement) SetInt(x uint64) {
+	v[0] = x
+	v[1] = 0
+	v[2] = 0
+	v[3] = 0
+	v[4] = 0
+}
+
+// FeReduce sets t = v mod p, fully reduced to its canonical representative.
+func FeReduce(t, v *FieldElement) {
+	// Copy v
+	*t = *v
+
+	// Let v = v[0] + v[1]*2^51 + v[2]*2^102 + v[3]*2^153 + v[4]*2^204.
+	// Reduce each limb below 2^51, propagating carries.
+	t[1] += t[0] >> 51
+	t[0] = t[0] & maskLow51Bits
+	t[2] += t[1] >> 51
+	t[1] = t[1] & maskLow51Bits
+	t[3] += t[2] >> 51
+	t[2] = t[2] & maskLow51Bits
+	t[4] += t[3] >> 51
+	t[3] = t[3] & maskLow51Bits
+	t[0] += (t[4] >> 51) * 19
+	t[4] = t[4] & maskLow51Bits
+
+	// We now have a field element t < 2^255, but need one < 2^255-19. If t is
+	// in [2^255-19, 2^255), then t+19 is in [2^255, 2^255+19), so bit 255 of
+	// t+19 -- the carry c computed below -- is set exactly when t >= 2^255-19.
+	// Adding 19*c and then masking away bit 255 subtracts p when c = 1 and
+	// changes nothing when c = 0.
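+	// Example: t = 2^255-19 gives t+19 = 2^255, so c = 1 and the final
+	// carry-and-mask pass computes t + 19 - 2^255 = t - p = 0.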
+
+	// Get the carry bit
+	c := (t[0] + 19) >> 51
+	c = (t[1] + c) >> 51
+	c = (t[2] + c) >> 51
+	c = (t[3] + c) >> 51
+	c = (t[4] + c) >> 51
+
+	t[0] += 19 * c
+
+	t[1] += t[0] >> 51
+	t[0] = t[0] & maskLow51Bits
+	t[2] += t[1] >> 51
+	t[1] = t[1] & maskLow51Bits
+	t[3] += t[2] >> 51
+	t[2] = t[2] & maskLow51Bits
+	t[4] += t[3] >> 51
+	t[3] = t[3] & maskLow51Bits
+	// no additional carry
+	t[4] = t[4] & maskLow51Bits
+}
+
+// FeAdd sets out = a + b. Long sequences of additions without reduction would
+// let coefficients grow larger than 54 bits, which the multiplication code
+// cannot tolerate. The paper cautions: "do not have such sequences of
+// additions".
+func FeAdd(out, a, b *FieldElement) {
+	out[0] = a[0] + b[0]
+	out[1] = a[1] + b[1]
+	out[2] = a[2] + b[2]
+	out[3] = a[3] + b[3]
+	out[4] = a[4] + b[4]
+}
+
+// FeSub sets out = a - b
+func FeSub(out, a, b *FieldElement) {
+	var t FieldElement
+	t = *b
+
+	// Reduce each limb below 2^51, propagating carries. Ensures that results
+	// fit within the limbs. This would not be required for reduced input.
+	t[1] += t[0] >> 51
+	t[0] = t[0] & maskLow51Bits
+	t[2] += t[1] >> 51
+	t[1] = t[1] & maskLow51Bits
+	t[3] += t[2] >> 51
+	t[2] = t[2] & maskLow51Bits
+	t[4] += t[3] >> 51
+	t[3] = t[3] & maskLow51Bits
+	t[0] += (t[4] >> 51) * 19
+	t[4] = t[4] & maskLow51Bits
+
+	// This is slightly more complicated. Because we use unsigned coefficients,
+	// we first add a multiple of p and then subtract. The constants below are
+	// 2*p expressed in radix 2^51, so the subtraction cannot underflow.
+	out[0] = (a[0] + 0xFFFFFFFFFFFDA) - t[0]
+	out[1] = (a[1] + 0xFFFFFFFFFFFFE) - t[1]
+	out[2] = (a[2] + 0xFFFFFFFFFFFFE) - t[2]
+	out[3] = (a[3] + 0xFFFFFFFFFFFFE) - t[3]
+	out[4] = (a[4] + 0xFFFFFFFFFFFFE) - t[4]
+}
+
+// FeNeg sets out = -a
+func FeNeg(out, a *FieldElement) {
+	var t FieldElement
+	t.SetInt(0)
+	FeSub(out, &t, a)
+}
+
+// FeInvert sets out = 1/z mod p by calculating z^(p-2), p-2 = 2^255 - 21.
+func FeInvert(out, z *FieldElement) {
+	// Inversion is implemented as exponentiation with exponent p − 2. It uses
+	// the same sequence of 254 squarings and 11 multiplications as [Curve25519].
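+	// (This yields the inverse because z^(p-2) * z = z^(p-1) = 1 mod p for
+	// nonzero z, by Fermat's little theorem.)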
+	var z2, z9, z11, z2_5_0, z2_10_0, z2_20_0, z2_50_0, z2_100_0, t FieldElement
+
+	FeSquare(&z2, z)        // 2
+	FeSquare(&t, &z2)       // 4
+	FeSquare(&t, &t)        // 8
+	FeMul(&z9, &t, z)       // 9
+	FeMul(&z11, &z9, &z2)   // 11
+	FeSquare(&t, &z11)      // 22
+	FeMul(&z2_5_0, &t, &z9) // 2^5 - 2^0 = 31
+
+	FeSquare(&t, &z2_5_0) // 2^6 - 2^1
+	for i := 0; i < 4; i++ {
+		FeSquare(&t, &t) // 2^10 - 2^5
+	}
+	FeMul(&z2_10_0, &t, &z2_5_0) // 2^10 - 2^0
+
+	FeSquare(&t, &z2_10_0) // 2^11 - 2^1
+	for i := 0; i < 9; i++ {
+		FeSquare(&t, &t) // 2^20 - 2^10
+	}
+	FeMul(&z2_20_0, &t, &z2_10_0) // 2^20 - 2^0
+
+	FeSquare(&t, &z2_20_0) // 2^21 - 2^1
+	for i := 0; i < 19; i++ {
+		FeSquare(&t, &t) // 2^40 - 2^20
+	}
+	FeMul(&t, &t, &z2_20_0) // 2^40 - 2^0
+
+	FeSquare(&t, &t) // 2^41 - 2^1
+	for i := 0; i < 9; i++ {
+		FeSquare(&t, &t) // 2^50 - 2^10
+	}
+	FeMul(&z2_50_0, &t, &z2_10_0) // 2^50 - 2^0
+
+	FeSquare(&t, &z2_50_0) // 2^51 - 2^1
+	for i := 0; i < 49; i++ {
+		FeSquare(&t, &t) // 2^100 - 2^50
+	}
+	FeMul(&z2_100_0, &t, &z2_50_0) // 2^100 - 2^0
+
+	FeSquare(&t, &z2_100_0) // 2^101 - 2^1
+	for i := 0; i < 99; i++ {
+		FeSquare(&t, &t) // 2^200 - 2^100
+	}
+	FeMul(&t, &t, &z2_100_0) // 2^200 - 2^0
+
+	FeSquare(&t, &t) // 2^201 - 2^1
+	for i := 0; i < 49; i++ {
+		FeSquare(&t, &t) // 2^250 - 2^50
+	}
+	FeMul(&t, &t, &z2_50_0) // 2^250 - 2^0
+
+	FeSquare(&t, &t) // 2^251 - 2^1
+	FeSquare(&t, &t) // 2^252 - 2^2
+	FeSquare(&t, &t) // 2^253 - 2^3
+	FeSquare(&t, &t) // 2^254 - 2^4
+	FeSquare(&t, &t) // 2^255 - 2^5
+
+	FeMul(out, &t, &z11) // 2^255 - 21
+}
+
+// FeCSwap implements a constant-time conditional swap. That is:
+// (a, b) = (b, a) if c = 1
+// (a, b) = (a, b) if c = 0
+func FeCSwap(a, b *FieldElement, c uint64) {
+	var t FieldElement
+	swap := -c // c = 1 gives an all-ones mask, c = 0 gives zero
+
+	t[0] = swap & (a[0] ^ b[0])
+	t[1] = swap & (a[1] ^ b[1])
+	t[2] = swap & (a[2] ^ b[2])
+	t[3] = swap & (a[3] ^ b[3])
+	t[4] = swap & (a[4] ^ b[4])
+
+	a[0] ^= t[0]
+	a[1] ^= t[1]
+	a[2] ^= t[2]
+	a[3] ^= t[3]
+	a[4] ^= t[4]
+
+	b[0] ^= t[0]
+	b[1] ^= t[1]
+	b[2] ^= t[2]
+	b[3] ^= t[3]
+	b[4] ^= t[4]
+}
+
+// FeFromBytes unpacks 32 bytes (little-endian) into five 51-bit limbs,
+// ignoring the top bit of the final byte.
+func FeFromBytes(v *FieldElement, x *[32]byte) {
+	v[0] = uint64(x[0])
+	v[0] |= uint64(x[1]) << 8
+	v[0] |= uint64(x[2]) << 16
+	v[0] |= uint64(x[3]) << 24
+	v[0] |= uint64(x[4]) << 32
+	v[0] |= uint64(x[5]) << 40
+	v[0] |= uint64(x[6]&7) << 48
+
+	v[1] = uint64(x[6]) >> 3
+	v[1] |= uint64(x[7]) << 5
+	v[1] |= uint64(x[8]) << 13
+	v[1] |= uint64(x[9]) << 21
+	v[1] |= uint64(x[10]) << 29
+	v[1] |= uint64(x[11]) << 37
+	v[1] |= uint64(x[12]&63) << 45
+
+	v[2] = uint64(x[12]) >> 6
+	v[2] |= uint64(x[13]) << 2
+	v[2] |= uint64(x[14]) << 10
+	v[2] |= uint64(x[15]) << 18
+	v[2] |= uint64(x[16]) << 26
+	v[2] |= uint64(x[17]) << 34
+	v[2] |= uint64(x[18]) << 42
+	v[2] |= uint64(x[19]&1) << 50
+
+	v[3] = uint64(x[19]) >> 1
+	v[3] |= uint64(x[20]) << 7
+	v[3] |= uint64(x[21]) << 15
+	v[3] |= uint64(x[22]) << 23
+	v[3] |= uint64(x[23]) << 31
+	v[3] |= uint64(x[24]) << 39
+	v[3] |= uint64(x[25]&15) << 47
+
+	v[4] = uint64(x[25]) >> 4
+	v[4] |= uint64(x[26]) << 4
+	v[4] |= uint64(x[27]) << 12
+	v[4] |= uint64(x[28]) << 20
+	v[4] |= uint64(x[29]) << 28
+	v[4] |= uint64(x[30]) << 36
+	v[4] |= uint64(x[31]&127) << 44
+}
+
+// FeToBytes packs v into 32 bytes (little-endian), first reducing it to its
+// canonical representative.
+func FeToBytes(r *[32]byte, v *FieldElement) {
+	var t FieldElement
+	FeReduce(&t, v)
+
+	r[0] = byte(t[0] & 0xff)
+	r[1] = byte((t[0] >> 8) & 0xff)
+	r[2] = byte((t[0] >> 16) & 0xff)
+	r[3] = byte((t[0] >> 24) & 0xff)
+	r[4] = byte((t[0] >> 32) & 0xff)
+	r[5] = byte((t[0] >> 40) & 0xff)
+	r[6] = byte((t[0] >> 48))
+
+	r[6] ^= byte((t[1] << 3) & 0xf8)
+	r[7] = byte((t[1] >> 5) & 0xff)
+	r[8] = byte((t[1] >> 13) & 0xff)
+	r[9] = byte((t[1] >> 21) & 0xff)
+	r[10] = byte((t[1] >> 29) & 0xff)
+	r[11] = byte((t[1] >> 37) & 0xff)
+	r[12] = byte((t[1] >> 45))
+
+	r[12] ^= byte((t[2] << 6) & 0xc0)
+	r[13] = byte((t[2] >> 2) & 0xff)
+	r[14] = byte((t[2] >> 10) & 0xff)
+	r[15] = byte((t[2] >> 18) & 0xff)
+	r[16] = byte((t[2] >> 26) & 0xff)
+	r[17] = byte((t[2] >> 34) & 0xff)
+	r[18] = byte((t[2] >> 42) & 0xff)
+	r[19] = byte((t[2] >> 50))
+
+	r[19] ^= byte((t[3] << 1) & 0xfe)
+	r[20] = byte((t[3] >> 7) & 0xff)
+	r[21] = byte((t[3] >> 15) & 0xff)
+	r[22] = byte((t[3] >> 23) & 0xff)
+	r[23] = byte((t[3] >> 31) & 0xff)
+	r[24] = byte((t[3] >> 39) & 0xff)
+	r[25] = byte((t[3] >> 47))
+
+	r[25] ^= byte((t[4] << 4) & 0xf0)
+	r[26] = byte((t[4] >> 4) & 0xff)
+	r[27] = byte((t[4] >> 12) & 0xff)
+	r[28] = byte((t[4] >> 20) & 0xff)
+	r[29] = byte((t[4] >> 28) & 0xff)
+	r[30] = byte((t[4] >> 36) & 0xff)
+	r[31] = byte((t[4] >> 44))
+}
diff --git a/internal/radix51/fe_mul.go b/internal/radix51/fe_mul.go
new file mode 100644
index 0000000..68b6849
--- /dev/null
+++ b/internal/radix51/fe_mul.go
@@ -0,0 +1,122 @@
+// +build !amd64 noasm

+package radix51
+
+// FeMul sets out = x * y
+func FeMul(out, x, y *FieldElement) {
+	var x0, x1, x2, x3, x4 uint64
+	var y0, y1, y2, y3, y4 uint64
+
+	x0 = x[0]
+	x1 = x[1]
+	x2 = x[2]
+	x3 = x[3]
+	x4 = x[4]
+
+	y0 = y[0]
+	y1 = y[1]
+	y2 = y[2]
+	y3 = y[3]
+	y4 = y[4]
+
+	// Reduction can be carried out simultaneously with multiplication. For
+	// example, we do not compute a coefficient r_5. Whenever the result of a
+	// mul instruction belongs to r_5, for example in the multiplication of
+	// x_3*y_2, we multiply one of the inputs by 19 and add the result to r_0.
+
+	x1_19 := x1 * 19
+	x2_19 := x2 * 19
+	x3_19 := x3 * 19
+	x4_19 := x4 * 19
+
+	// calculate r0 = x0*y0 + 19*(x1*y4 + x2*y3 + x3*y2 + x4*y1)
+	r00, r01 := mul64x64(0, 0, x0, y0)
+	r00, r01 = mul64x64(r00, r01, x1_19, y4)
+	r00, r01 = mul64x64(r00, r01, x2_19, y3)
+	r00, r01 = mul64x64(r00, r01, x3_19, y2)
+	r00, r01 = mul64x64(r00, r01, x4_19, y1)
+
+	// calculate r1 = x0*y1 + x1*y0 + 19*(x2*y4 + x3*y3 + x4*y2)
+	r10, r11 := mul64x64(0, 0, x0, y1)
+	r10, r11 = mul64x64(r10, r11, x1, y0)
+	r10, r11 = mul64x64(r10, r11, x2_19, y4)
+	r10, r11 = mul64x64(r10, r11, x3_19, y3)
+	r10, r11 = mul64x64(r10, r11, x4_19, y2)
+
+	// calculate r2 = x0*y2 + x1*y1 + x2*y0 + 19*(x3*y4 + x4*y3)
+	r20, r21 := mul64x64(0, 0, x0, y2)
+	r20, r21 = mul64x64(r20, r21, x1, y1)
+	r20, r21 = mul64x64(r20, r21, x2, y0)
+	r20, r21 = mul64x64(r20, r21, x3_19, y4)
+	r20, r21 = mul64x64(r20, r21, x4_19, y3)
+
+	// calculate r3 = x0*y3 + x1*y2 + x2*y1 + x3*y0 + 19*x4*y4
+	r30, r31 := mul64x64(0, 0, x0, y3)
+	r30, r31 = mul64x64(r30, r31, x1, y2)
+	r30, r31 = mul64x64(r30, r31, x2, y1)
+	r30, r31 = mul64x64(r30, r31, x3, y0)
+	r30, r31 = mul64x64(r30, r31, x4_19, y4)
+
+	// calculate r4 = x0*y4 + x1*y3 + x2*y2 + x3*y1 + x4*y0
+	r40, r41 := mul64x64(0, 0, x0, y4)
+	r40, r41 = mul64x64(r40, r41, x1, y3)
+	r40, r41 = mul64x64(r40, r41, x2, y2)
+	r40, r41 = mul64x64(r40, r41, x3, y1)
+	r40, r41 = mul64x64(r40, r41, x4, y0)
+
+	// After the multiplication we need to reduce (carry) the five coefficients
+	// to obtain a result with coefficients that are at most slightly larger
+	// than 2^51. Denote the two registers holding coefficient r_0 as r_00 and
+	// r_01, with r_0 = 2^64*r_01 + r_00. Similarly denote the two registers
+	// holding coefficient r_1 as r_10 and r_11.
+	// We first shift r_01 left by 13, shifting in the most significant bits
+	// of r_00 (the shld instruction), and then mask r_00 with 2^51 - 1. We do
+	// the same with r_10 and r_11, and add r_01 into r_10 after the mask. We
+	// proceed this way for coefficients r_2,...,r_4; register r_41 is
+	// multiplied by 19 before adding it to r_00.
+
+	r01 = (r01 << 13) | (r00 >> 51)
+	r00 &= maskLow51Bits
+
+	r11 = (r11 << 13) | (r10 >> 51)
+	r10 &= maskLow51Bits
+	r10 += r01
+
+	r21 = (r21 << 13) | (r20 >> 51)
+	r20 &= maskLow51Bits
+	r20 += r11
+
+	r31 = (r31 << 13) | (r30 >> 51)
+	r30 &= maskLow51Bits
+	r30 += r21
+
+	r41 = (r41 << 13) | (r40 >> 51)
+	r40 &= maskLow51Bits
+	r40 += r31
+
+	r41 *= 19
+	r00 += r41
+
+	// Now all five coefficients fit into 64-bit registers but are still too
+	// large to be used as input to another multiplication. We therefore carry
+	// from r_0 to r_1, from r_1 to r_2, from r_2 to r_3, from r_3 to r_4, and
+	// finally from r_4 to r_0. Each of these carries is done as one copy, one
+	// right shift by 51, one logical and with 2^51 - 1, and one addition.
+
+	r10 += r00 >> 51
+	r00 &= maskLow51Bits
+	r20 += r10 >> 51
+	r10 &= maskLow51Bits
+	r30 += r20 >> 51
+	r20 &= maskLow51Bits
+	r40 += r30 >> 51
+	r30 &= maskLow51Bits
+	r00 += (r40 >> 51) * 19
+	r40 &= maskLow51Bits
+
+	out[0] = r00
+	out[1] = r10
+	out[2] = r20
+	out[3] = r30
+	out[4] = r40
+}
diff --git a/internal/radix51/fe_mul_amd64.go b/internal/radix51/fe_mul_amd64.go
new file mode 100644
index 0000000..c76b4db
--- /dev/null
+++ b/internal/radix51/fe_mul_amd64.go
@@ -0,0 +1,6 @@
+// +build amd64,!noasm
+
+package radix51
+
+//go:noescape
+func FeMul(out, a, b *FieldElement)
diff --git a/internal/radix51/fe_mul_amd64.s b/internal/radix51/fe_mul_amd64.s
new file mode 100644
index 0000000..07ebb0c
--- /dev/null
+++ b/internal/radix51/fe_mul_amd64.s
@@ -0,0 +1,198 @@
+// Based on assembly generated by PeachPy. Equivalent to the Go in fe_mul.go,
+// which was originally based on the amd64-51-30k assembly in SUPERCOP.
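+//
+// The Go prototype (fe_mul_amd64.go) is FeMul(out, a, b *FieldElement); with
+// the stack-based amd64 calling convention, those three pointers are read
+// below from outp+0(FP), xp+8(FP), and yp+16(FP).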
+
+// +build amd64,!noasm
+
+// func FeMul(outp *uint64, xp *uint64, yp *uint64)
+TEXT ·FeMul(SB),$0-24
+	MOVQ outp+0(FP), DI
+	MOVQ xp+8(FP), BX
+	MOVQ yp+16(FP), CX
+
+	// Calculate r0
+	MOVQ 0(BX), AX     // rax <-- x0
+	MULQ 0(CX)         // rdx, rax <-- x0*y0
+	MOVQ AX, SI        // r00 = rax
+	MOVQ DX, BP        // r01 = rdx
+
+	MOVQ 8(BX), DX     // rdx <-- x1
+	IMUL3Q $19, DX, AX // rax <-- x1*19
+	MULQ 32(CX)        // rdx, rax <-- x1_19*y4
+	ADDQ AX, SI        // r00 += rax
+	ADCQ DX, BP        // r01 += rdx
+
+	MOVQ 16(BX), DX    // rdx <-- x2
+	IMUL3Q $19, DX, AX // rax <-- x2*19
+	MULQ 24(CX)        // rdx, rax <-- x2_19*y3
+	ADDQ AX, SI        // r00 += rax
+	ADCQ DX, BP        // r01 += rdx
+
+	MOVQ 24(BX), DX    // rdx <-- x3
+	IMUL3Q $19, DX, AX // rax <-- x3*19
+	MULQ 16(CX)        // rdx, rax <-- x3_19 * y2
+	ADDQ AX, SI        // r00 += rax
+	ADCQ DX, BP        // r01 += rdx
+
+	MOVQ 32(BX), DX    // rdx <-- x4
+	IMUL3Q $19, DX, AX // rax <-- x4*19
+	MULQ 8(CX)         // rdx, rax <-- x4_19*y1
+	ADDQ AX, SI        // r00 += rax
+	ADCQ DX, BP        // r01 += rdx
+
+	// Calculate r1
+	MOVQ 0(BX), AX
+	MULQ 8(CX)
+	MOVQ AX, R8  // r10
+	MOVQ DX, R9  // r11
+
+	MOVQ 8(BX), AX
+	MULQ 0(CX)
+	ADDQ AX, R8
+	ADCQ DX, R9
+
+	MOVQ 16(BX), DX
+	IMUL3Q $19, DX, AX
+	MULQ 32(CX)
+	ADDQ AX, R8
+	ADCQ DX, R9
+
+	MOVQ 24(BX), DX
+	IMUL3Q $19, DX, AX
+	MULQ 24(CX)
+	ADDQ AX, R8
+	ADCQ DX, R9
+
+	MOVQ 32(BX), DX
+	IMUL3Q $19, DX, AX
+	MULQ 16(CX)
+	ADDQ AX, R8
+	ADCQ DX, R9
+
+	// Calculate r2
+	MOVQ 0(BX), AX
+	MULQ 16(CX)
+	MOVQ AX, R10 // r20
+	MOVQ DX, R11 // r21
+
+	MOVQ 8(BX), AX
+	MULQ 8(CX)
+	ADDQ AX, R10
+	ADCQ DX, R11
+
+	MOVQ 16(BX), AX
+	MULQ 0(CX)
+	ADDQ AX, R10
+	ADCQ DX, R11
+
+	MOVQ 24(BX), DX
+	IMUL3Q $19, DX, AX
+	MULQ 32(CX)
+	ADDQ AX, R10
+	ADCQ DX, R11
+
+	MOVQ 32(BX), DX
+	IMUL3Q $19, DX, AX
+	MULQ 24(CX)
+	ADDQ AX, R10
+	ADCQ DX, R11
+
+	// Calculate r3
+	MOVQ 0(BX), AX
+	MULQ 24(CX)
+	MOVQ AX, R12 // r30
+	MOVQ DX, R13 // r31
+
+	MOVQ 8(BX), AX
+	MULQ 16(CX)
+	ADDQ AX, R12
+	ADCQ DX, R13
+
+	MOVQ 16(BX), AX
+	MULQ 8(CX)
+	ADDQ AX, R12
+	ADCQ DX, R13
+
+	MOVQ 24(BX), AX
+	MULQ 0(CX)
+	ADDQ AX, R12
+	ADCQ DX, R13
+
+	MOVQ 32(BX), DX
+	IMUL3Q $19, DX, AX
+	MULQ 32(CX)
+	ADDQ AX, R12
+	ADCQ DX, R13
+
+	// Calculate r4
+	MOVQ 0(BX), AX
+	MULQ 32(CX)
+	MOVQ AX, R14 // r40
+	MOVQ DX, R15 // r41
+
+	MOVQ 8(BX), AX
+	MULQ 24(CX)
+	ADDQ AX, R14
+	ADCQ DX, R15
+
+	MOVQ 16(BX), AX
+	MULQ 16(CX)
+	ADDQ AX, R14
+	ADCQ DX, R15
+
+	MOVQ 24(BX), AX
+	MULQ 8(CX)
+	ADDQ AX, R14
+	ADCQ DX, R15
+
+	MOVQ 32(BX), AX
+	MULQ 0(CX)
+	ADDQ AX, R14
+	ADCQ DX, R15
+
+	MOVQ $2251799813685247, AX // (1<<51) - 1
+	SHLQ $13, SI, BP           // r01 = shld with r00
+	ANDQ AX, SI                // r00 &= mask51
+	SHLQ $13, R8, R9           // r11 = shld with r10
+	ANDQ AX, R8                // r10 &= mask51
+	ADDQ BP, R8                // r10 += r01
+	SHLQ $13, R10, R11         // r21 = shld with r20
+	ANDQ AX, R10               // r20 &= mask51
+	ADDQ R9, R10               // r20 += r11
+	SHLQ $13, R12, R13         // r31 = shld with r30
+	ANDQ AX, R12               // r30 &= mask51
+	ADDQ R11, R12              // r30 += r21
+	SHLQ $13, R14, R15         // r41 = shld with r40
+	ANDQ AX, R14               // r40 &= mask51
+	ADDQ R13, R14              // r40 += r31
+	IMUL3Q $19, R15, R15       // r41 = r41*19
+	ADDQ R15, SI               // r00 += r41
+
+	MOVQ SI, DX  // rdx <-- r00
+	SHRQ $51, DX // rdx <-- r00 >> 51
+	ADDQ DX, R8  // r10 += r00 >> 51
+	MOVQ R8, DX  // rdx <-- r10
+	SHRQ $51, DX // rdx <-- r10 >> 51
+	ANDQ AX, SI  // r00 &= mask51
+	ADDQ DX, R10 // r20 += r10 >> 51
+	MOVQ R10, DX // rdx <-- r20
+	SHRQ $51, DX // rdx <-- r20 >> 51
+	ANDQ AX, R8  // r10 &= mask51
+	ADDQ DX, R12 // r30 += r20 >> 51
+	MOVQ R12, DX // rdx <-- r30
+	SHRQ $51, DX // rdx <-- r30 >> 51
+	ANDQ AX, R10 // r20 &= mask51
+	ADDQ DX, R14       // r40 += r30 >> 51
+	MOVQ R14, DX       // rdx <-- r40
+	SHRQ $51, DX       // rdx <-- r40 >> 51
+	ANDQ AX, R12       // r30 &= mask51
+	IMUL3Q $19, DX, DX // rdx <-- (r40 >> 51) * 19
+	ADDQ DX, SI        // r00 += (r40 >> 51) * 19
+	ANDQ AX, R14       // r40 &= mask51
+
+	MOVQ SI, 0(DI)
+	MOVQ R8, 8(DI)
+	MOVQ R10, 16(DI)
+	MOVQ R12, 24(DI)
+	MOVQ R14, 32(DI)
+	RET
diff --git a/internal/radix51/fe_square.go b/internal/radix51/fe_square.go
new file mode 100644
index 0000000..fa738eb
--- /dev/null
+++ b/internal/radix51/fe_square.go
@@ -0,0 +1,94 @@
+// +build !amd64 noasm
+
+package radix51
+
+// FeSquare sets out = x*x
+func FeSquare(out, x *FieldElement) {
+	// Squaring needs only 15 mul instructions. Some inputs are multiplied by 2;
+	// this is combined with multiplication by 19 where possible. The coefficient
+	// reduction after squaring is the same as for multiplication.
+
+	var x0, x1, x2, x3, x4 uint64
+
+	x0 = x[0]
+	x1 = x[1]
+	x2 = x[2]
+	x3 = x[3]
+	x4 = x[4]
+
+	x0_2 := x0 << 1
+	x1_2 := x1 << 1
+
+	x1_38 := x1 * 38
+	x2_38 := x2 * 38
+	x3_38 := x3 * 38
+
+	x3_19 := x3 * 19
+	x4_19 := x4 * 19
+
+	// r0 = x0*x0 + x1*38*x4 + x2*38*x3
+	r00, r01 := mul64x64(0, 0, x0, x0)
+	r00, r01 = mul64x64(r00, r01, x1_38, x4)
+	r00, r01 = mul64x64(r00, r01, x2_38, x3)
+
+	// r1 = x0*2*x1 + x2*38*x4 + x3*19*x3
+	r10, r11 := mul64x64(0, 0, x0_2, x1)
+	r10, r11 = mul64x64(r10, r11, x2_38, x4)
+	r10, r11 = mul64x64(r10, r11, x3_19, x3)
+
+	// r2 = x0*2*x2 + x1*x1 + x3*38*x4
+	r20, r21 := mul64x64(0, 0, x0_2, x2)
+	r20, r21 = mul64x64(r20, r21, x1, x1)
+	r20, r21 = mul64x64(r20, r21, x3_38, x4)
+
+	// r3 = x0*2*x3 + x1*2*x2 + x4*19*x4
+	r30, r31 := mul64x64(0, 0, x0_2, x3)
+	r30, r31 = mul64x64(r30, r31, x1_2, x2)
+	r30, r31 = mul64x64(r30, r31, x4_19, x4)
+
+	// r4 = x0*2*x4 + x1*2*x3 + x2*x2
+	r40, r41 := mul64x64(0, 0, x0_2, x4)
+	r40, r41 = mul64x64(r40, r41, x1_2, x3)
+	r40, r41 = mul64x64(r40, r41, x2, x2)
+
+	// Same reduction as in FeMul
+
+	r01 = (r01 << 13) | (r00 >> 51)
+	r00 &= maskLow51Bits
+
+	r11 = (r11 << 13) | (r10 >> 51)
+	r10 &= maskLow51Bits
+	r10 += r01
+
+	r21 = (r21 << 13) | (r20 >> 51)
+	r20 &= maskLow51Bits
+	r20 += r11
+
+	r31 = (r31 << 13) | (r30 >> 51)
+	r30 &= maskLow51Bits
+	r30 += r21
+
+	r41 = (r41 << 13) | (r40 >> 51)
+	r40 &= maskLow51Bits
+	r40 += r31
+
+	r41 *= 19
+	r00 += r41
+
+	r10 += r00 >> 51
+	r00 &= maskLow51Bits
+	r20 += r10 >> 51
+	r10 &= maskLow51Bits
+	r30 += r20 >> 51
+	r20 &= maskLow51Bits
+	r40 += r30 >> 51
+	r30 &= maskLow51Bits
+	r00 += (r40 >> 51) * 19
+	r40 &= maskLow51Bits
+
+	out[0] = r00
+	out[1] = r10
+	out[2] = r20
+	out[3] = r30
+	out[4] = r40
+}
diff --git a/internal/radix51/fe_square_amd64.go b/internal/radix51/fe_square_amd64.go
new file mode 100644
index 0000000..c089146
--- /dev/null
+++ b/internal/radix51/fe_square_amd64.go
@@ -0,0 +1,6 @@
+// +build amd64,!noasm
+
+package radix51
+
+//go:noescape
+func FeSquare(out, x *FieldElement)
diff --git a/internal/radix51/fe_square_amd64.s b/internal/radix51/fe_square_amd64.s
new file mode 100644
index 0000000..f9af5c8
--- /dev/null
+++ b/internal/radix51/fe_square_amd64.s
@@ -0,0 +1,146 @@
+// +build amd64,!noasm
+
+// func FeSquare(outp *uint64, xp *uint64)
+TEXT ·FeSquare(SB),4,$0-16
+	MOVQ outp+0(FP), DI
+	MOVQ xp+8(FP), SI
+
+	// r0 = x0*x0 + x1*38*x4 + x2*38*x3
+	MOVQ 0(SI), AX
+	MULQ 0(SI)
+	MOVQ AX, CX // r00
+	MOVQ DX, R8 // r01
+
+	MOVQ 8(SI), DX
+	IMUL3Q $38, DX, AX
+	MULQ 32(SI)
+	ADDQ AX, CX
+	ADCQ DX, R8
+
+	MOVQ 16(SI), DX
+	IMUL3Q $38, DX, AX
+	MULQ 24(SI)
+	ADDQ AX, CX
+	ADCQ DX, R8
+
+	// r1 = x0*2*x1 + x2*38*x4 + x3*19*x3
+	MOVQ 0(SI), AX
+	SHLQ $1, AX
+	MULQ 8(SI)
+	MOVQ AX, R9  // r10
+	MOVQ DX, R10 // r11
+
+	MOVQ 16(SI), DX
+	IMUL3Q $38, DX, AX
+	MULQ 32(SI)
+	ADDQ AX, R9
+	ADCQ DX, R10
+
+	MOVQ 24(SI), DX
+	IMUL3Q $19, DX, AX
+	MULQ 24(SI)
+	ADDQ AX, R9
+	ADCQ DX, R10
+
+	// r2 = x0*2*x2 + x1*x1 + x3*38*x4
+	MOVQ 0(SI), AX
+	SHLQ $1, AX
+	MULQ 16(SI)
+	MOVQ AX, R11 // r20
+	MOVQ DX, R12 // r21
+
+	MOVQ 8(SI), AX
+	MULQ 8(SI)
+	ADDQ AX, R11
+	ADCQ DX, R12
+
+	MOVQ 24(SI), DX
+	IMUL3Q $38, DX, AX
+	MULQ 32(SI)
+	ADDQ AX, R11
+	ADCQ DX, R12
+
+	// r3 = x0*2*x3 + x1*2*x2 + x4*19*x4
+	MOVQ 0(SI), AX
+	SHLQ $1, AX
+	MULQ 24(SI)
+	MOVQ AX, R13 // r30
+	MOVQ DX, R14 // r31
+
+	MOVQ 8(SI), AX
+	SHLQ $1, AX
+	MULQ 16(SI)
+	ADDQ AX, R13
+	ADCQ DX, R14
+
+	MOVQ 32(SI), DX
+	IMUL3Q $19, DX, AX
+	MULQ 32(SI)
+	ADDQ AX, R13
+	ADCQ DX, R14
+
+	// r4 = x0*2*x4 + x1*2*x3 + x2*x2
+	MOVQ 0(SI), AX
+	SHLQ $1, AX
+	MULQ 32(SI)
+	MOVQ AX, R15 // r40
+	MOVQ DX, BX  // r41
+
+	MOVQ 8(SI), AX
+	SHLQ $1, AX
+	MULQ 24(SI)
+	ADDQ AX, R15
+	ADCQ DX, BX
+
+	MOVQ 16(SI), AX
+	MULQ 16(SI)
+	ADDQ AX, R15
+	ADCQ DX, BX
+
+	// Reduce
+	MOVQ $2251799813685247, AX // (1<<51) - 1
+	SHLQ $13, CX, R8           // r01 = shld with r00
+	ANDQ AX, CX                // r00 &= mask51
+	SHLQ $13, R9, R10          // r11 = shld with r10
+	ANDQ AX, R9                // r10 &= mask51
+	ADDQ R8, R9                // r10 += r01
+	SHLQ $13, R11, R12         // r21 = shld with r20
+	ANDQ AX, R11               // r20 &= mask51
+	ADDQ R10, R11              // r20 += r11
+	SHLQ $13, R13, R14         // r31 = shld with r30
+	ANDQ AX, R13               // r30 &= mask51
+	ADDQ R12, R13              // r30 += r21
+	SHLQ $13, R15, BX          // r41 = shld with r40
+	ANDQ AX, R15               // r40 &= mask51
+	ADDQ R14, R15              // r40 += r31
+	IMUL3Q $19, BX, DX         // r41 = r41*19
+	ADDQ DX, CX                // r00 += r41
+
+	MOVQ CX, DX  // rdx <-- r00
+	SHRQ $51, DX // rdx <-- r00 >> 51
+	ADDQ DX, R9  // r10 += r00 >> 51
+	MOVQ R9, DX  // rdx <-- r10
+	SHRQ $51, DX // rdx <-- r10 >> 51
+	ANDQ AX, CX  // r00 &= mask51
+	ADDQ DX, R11 // r20 += r10 >> 51
+	MOVQ R11, DX // rdx <-- r20
+	SHRQ $51, DX // rdx <-- r20 >> 51
+	ANDQ AX, R9  // r10 &= mask51
+	ADDQ DX, R13 // r30 += r20 >> 51
+	MOVQ R13, DX // rdx <-- r30
+	SHRQ $51, DX // rdx <-- r30 >> 51
+	ANDQ AX, R11 // r20 &= mask51
+	ADDQ DX, R15 // r40 += r30 >> 51
+	MOVQ R15, DX // rdx <-- r40
+	SHRQ $51, DX // rdx <-- r40 >> 51
+	ANDQ AX, R13 // r30 &= mask51
+	IMUL3Q $19, DX, DX // rdx <-- (r40 >> 51) * 19
+	ADDQ DX, CX        // r00 += (r40 >> 51) * 19
+	ANDQ AX, R15       // r40 &= mask51
+
+	MOVQ CX, 0(DI)
+	MOVQ R9, 8(DI)
+	MOVQ R11, 16(DI)
+	MOVQ R13, 24(DI)
+	MOVQ R15, 32(DI)
+	RET
diff --git a/internal/radix51/fe_test.go b/internal/radix51/fe_test.go
new file mode 100644
index 0000000..217d910
--- /dev/null
+++ b/internal/radix51/fe_test.go
@@ -0,0 +1,175 @@
+package radix51
+
+import (
+	"bytes"
+	"crypto/rand"
+	"io"
+	"testing"
+	"unsafe"
+)
+
+func TestMul64to128(t *testing.T) {
+	a := uint64(5)
+	b := uint64(5)
+	r0, r1 := mul64x64(0, 0, a, b)
+	if r0 != 0x19 || r1 != 0 {
+		t.Errorf("lo-range wide mult failed, got %d + %d*(2**64)", r0, r1)
+	}
+
+	a = uint64(18014398509481983) // 2^54 - 1
+	b = uint64(18014398509481983) // 2^54 - 1
+	r0, r1 = mul64x64(0, 0, a, b)
+	if r0 != 0xff80000000000001 || r1 != 0xfffffffffff {
+		t.Errorf("hi-range wide mult failed, got %d + %d*(2**64)", r0, r1)
+	}
+
+	a = uint64(1125899906842661)
+	b = uint64(2097155)
+	r0, r1 = mul64x64(0, 0, a, b)
+	r0, r1 = mul64x64(r0, r1, a, b)
+	r0, r1 = mul64x64(r0, r1, a, b)
+	r0, r1 = mul64x64(r0, r1, a, b)
+	r0, r1 = mul64x64(r0, r1, a, b)
+	if r0 != 16888498990613035 || r1 != 640 {
+		t.Errorf("wrong answer: %d + %d*(2**64)", r0, r1)
+	}
+}
+
+func BenchmarkWideMultInline(t *testing.B) {
+	var r0, r1, ol, oh uint64
+	a := uint64(18014398509481983) // 2^54 - 1
+	b := uint64(18014398509481983) // 2^54 - 1
+
+	for i := 0; i < t.N; i++ {
+		t1 := (a>>32)*(b&0xFFFFFFFF) + ((a & 0xFFFFFFFF) * (b & 0xFFFFFFFF) >> 32)
+		t2 := (a&0xFFFFFFFF)*(b>>32) + (t1 & 0xFFFFFFFF)
+		ol = (a * b) + r0
+		cmp := ol < r0
+		oh = r1 + (a>>32)*(b>>32) + t1>>32 + t2>>32 + uint64(*(*byte)(unsafe.Pointer(&cmp)))
+
+		r1 = oh
+		r0 = ol
+	}
+}
+
+func BenchmarkWideMultCall(t *testing.B) {
+	var r0, r1 uint64
+	a := uint64(18014398509481983)
+	b := uint64(18014398509481983)
+
+	for i := 0; i < t.N; i++ {
+		r0, r1 = mul64x64(r0, r1, a, b)
+	}
+}
+
+func TestFeFromBytesRoundTrip(t *testing.T) {
+	var in, out [32]byte
+	var fe, r FieldElement
+
+	in = [32]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
+		18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32}
+
+	FeFromBytes(&fe, &in)
+	FeToBytes(&out, &fe)
+
+	if !bytes.Equal(in[:], out[:]) {
+		t.Error("Bytes<>FE doesn't roundtrip")
+	}
+
+	// Random field element
+	fe[0] = 0x4e645be9215a2
+	fe[1] = 0x4e9654922df12
+	fe[2] = 0x5829e468b0205
+	fe[3] = 0x5e8fca9e0881c
+	fe[4] = 0x5c490f087d796
+
+	FeToBytes(&out, &fe)
+	FeFromBytes(&r, &out)
+
+	for i := 0; i < len(fe); i++ {
+		if r[i] != fe[i] {
+			t.Error("FE<>Bytes doesn't roundtrip")
+		}
+	}
+}
+
+// Tests self-consistency between FeMul and FeSquare.
+func TestSanity(t *testing.T) {
+	var x FieldElement
+	var x2, x2sq FieldElement
+	// var x2Go, x2sqGo FieldElement
+
+	x = [5]uint64{1, 1, 1, 1, 1}
+	FeMul(&x2, &x, &x)
+	// FeMulGo(&x2Go, &x, &x)
+	FeSquare(&x2sq, &x)
+	// FeSquareGo(&x2sqGo, &x)
+
+	// if !vartimeEqual(x2, x2Go) || !vartimeEqual(x2sq, x2sqGo) || !vartimeEqual(x2, x2sq) {
+	// 	t.Fatalf("all ones failed\nmul.s: %d\nmul.g: %d\nsqr.s: %d\nsqr.g: %d\n", x2, x2Go, x2sq, x2sqGo)
+	// }
+
+	if !vartimeEqual(x2, x2sq) {
+		t.Fatalf("all ones failed\nmul: %x\nsqr: %x\n", x2, x2sq)
+	}
+
+	var bytes [32]byte
+
+	_, err := io.ReadFull(rand.Reader, bytes[:])
+	if err != nil {
+		t.Fatal(err)
+	}
+	FeFromBytes(&x, &bytes)
+
+	FeMul(&x2, &x, &x)
+	// FeMulGo(&x2Go, &x, &x)
+	FeSquare(&x2sq, &x)
+	// FeSquareGo(&x2sqGo, &x)
+
+	// if !vartimeEqual(x2, x2Go) || !vartimeEqual(x2sq, x2sqGo) || !vartimeEqual(x2, x2sq) {
+	// 	t.Fatalf("random field element failed\nfe: %x\n\nmul.s: %x\nmul.g: %x\nsqr.s: %x\nsqr.g: %x\n", x, x2, x2Go, x2sq, x2sqGo)
+	// }
+
+	if !vartimeEqual(x2, x2sq) {
+		t.Fatalf("random field element failed\nmul: %x\nsqr: %x\n", x2, x2sq)
+	}
+}
+
+func vartimeEqual(x, y FieldElement) bool {
+	for i := 0; i < 5; i++ {
+		if x[i] != y[i] {
+			return false
+		}
+	}
+	return true
+}
+
+func TestFeInvert(t *testing.T) {
+	var x FieldElement = [5]uint64{1, 1, 1, 1, 1}
+	var one FieldElement = [5]uint64{1, 0, 0, 0, 0}
+	var xinv, r FieldElement
+
+	FeInvert(&xinv, &x)
+	FeMul(&r, &x, &xinv)
+	FeReduce(&r, &r)
+
+	if !vartimeEqual(one, r) {
+		t.Errorf("inversion identity failed, got: %x", r)
+	}
+
+	var bytes [32]byte
+
+	_, err := io.ReadFull(rand.Reader, bytes[:])
+	if err != nil {
+		t.Fatal(err)
+	}
+	FeFromBytes(&x, &bytes)
+
+	FeInvert(&xinv, &x)
+	FeMul(&r, &x, &xinv)
+	FeReduce(&r, &r)
+
+	if !vartimeEqual(one, r) {
+		t.Errorf("random inversion identity failed, got: %x for field element %x", r, x)
+	}
+}
diff --git a/internal/radix51/mul.go b/internal/radix51/mul.go
new file mode 100644
index 0000000..dbea36c
--- /dev/null
+++ b/internal/radix51/mul.go
@@ -0,0 +1,14 @@
+package radix51
+
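+// This file provides the portable 64x64->128 bit multiply-accumulate used by
+// the pure Go FeMul and FeSquare. The unsafe import exists only to turn the
+// carry comparison in mul64x64 into a 0/1 value without a branch.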
+import "unsafe"
+
+// mul64x64 multiplies two 64-bit numbers and adds the 128-bit product to the
+// 128-bit accumulator (hi, lo). This function is written to ensure it inlines.
+// I am so sorry.
+func mul64x64(lo, hi, a, b uint64) (ol uint64, oh uint64) {
+	t1 := (a>>32)*(b&0xFFFFFFFF) + ((a & 0xFFFFFFFF) * (b & 0xFFFFFFFF) >> 32)
+	t2 := (a&0xFFFFFFFF)*(b>>32) + (t1 & 0xFFFFFFFF)
+	ol = (a * b) + lo
+	cmp := ol < lo
+	oh = hi + (a>>32)*(b>>32) + t1>>32 + t2>>32 + uint64(*(*byte)(unsafe.Pointer(&cmp)))
+	return
+}
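One way to see (and cross-check) the mul64x64 contract — (oh, ol) = (hi, lo) + a*b as a 128-bit quantity — is against math/big. A standalone sketch, not part of the patch, reusing the 2^54-1 vector from TestMul64to128 in fe_test.go:

```go
package main

import (
	"fmt"
	"math/big"
)

// mul64x64Big mirrors radix51's mul64x64 with arbitrary-precision arithmetic:
// (oh, ol) = (hi, lo) + a*b, computed over 128 bits.
func mul64x64Big(lo, hi, a, b uint64) (ol, oh uint64) {
	acc := new(big.Int).Lsh(new(big.Int).SetUint64(hi), 64)
	acc.Add(acc, new(big.Int).SetUint64(lo))
	acc.Add(acc, new(big.Int).Mul(new(big.Int).SetUint64(a), new(big.Int).SetUint64(b)))

	mask := new(big.Int).SetUint64(^uint64(0))
	ol = new(big.Int).And(acc, mask).Uint64()
	oh = new(big.Int).And(acc.Rsh(acc, 64), mask).Uint64()
	return
}

func main() {
	a := uint64(18014398509481983) // 2^54 - 1
	ol, oh := mul64x64Big(0, 0, a, a)
	// Expect 0xff80000000000001 + 0xfffffffffff*(2**64), the values
	// asserted in TestMul64to128.
	fmt.Printf("%#x + %#x*(2**64)\n", ol, oh)
}
```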