diff --git a/ed25519_test.go b/ed25519_test.go index d92b269..665e662 100644 --- a/ed25519_test.go +++ b/ed25519_test.go @@ -377,6 +377,7 @@ var radix51A = field.FieldElement{ func BenchmarkFeMul51(b *testing.B) { var h field.FieldElement + b.ResetTimer() for i := 0; i < b.N; i++ { field.FeMul(&h, &radix51A, &radix51A) } @@ -384,6 +385,7 @@ func BenchmarkFeMul51(b *testing.B) { func BenchmarkFeMulADX(b *testing.B) { var h [10]uint64 + b.ResetTimer() for i := 0; i < b.N; i++ { field.FeMulADX(&h, &radix51A, &radix51A) } diff --git a/internal/radix51/fe_mul_amd64_adx.s b/internal/radix51/fe_mul_amd64_adx.s index 7f2075e..75aa320 100644 --- a/internal/radix51/fe_mul_amd64_adx.s +++ b/internal/radix51/fe_mul_amd64_adx.s @@ -13,12 +13,11 @@ TEXT ·FeMulADX(SB),NOSPLIT,$0 MOVQ yp+16(FP), BP // The first diagonal sets up the accumulators. + XORQ AX,AX MOVQ 0(BP), DX // rdx <-- y0 - MULXQ 0(SI), R8, R9 // r0 <-- x0*y0 MOVQ R8, 0(DI) - MULXQ 8(SI), R10, R11 // r1 <-- x1*y0 ADDQ R9, R10 MULXQ 16(SI), R12, R13 // r2 <-- x2*y0 @@ -28,6 +27,10 @@ TEXT ·FeMulADX(SB),NOSPLIT,$0 MULXQ 32(SI), BX, CX // r4 <-- x4*y0 ADCQ R15, BX + // Consolidate carry chain into r5 + ADCQ AX, CX + MOVQ CX, 40(DI) + MOVQ R10, 8(DI) MOVQ R12, 16(DI) MOVQ R14, 24(DI) @@ -36,11 +39,9 @@ TEXT ·FeMulADX(SB),NOSPLIT,$0 XORQ AX, AX // clear flags MOVQ 8(BP), DX // rdx <-- y1 - MULXQ 0(SI), R10, R11 // r1 <-- x0*y1 ADOXQ 8(DI), R10 MOVQ R10, 8(DI) - MULXQ 8(SI), R12, R13 // r2 <-- x1*y1 ADCXQ R11, R12 ADOXQ 16(DI), R12 @@ -54,6 +55,11 @@ TEXT ·FeMulADX(SB),NOSPLIT,$0 ADCXQ CX, R8 ADOXQ 40(DI), R8 + // Consolidate both carry chains into r6 + ADCXQ AX, R9 + ADOXQ AX, R9 + MOVQ R9, 48(DI) + // Update accumulators MOVQ R12, 16(DI) MOVQ R14, 24(DI) @@ -63,11 +69,9 @@ TEXT ·FeMulADX(SB),NOSPLIT,$0 XORQ AX, AX // clear flags MOVQ 16(BP), DX // rdx <-- y2 - MULXQ 0(SI), R12, R13 // r2 <-- x0*y2 ADOXQ 16(DI), R12 MOVQ R12, 16(DI) - MULXQ 8(SI), R14, R15 // r3 <-- x1*y2 ADCXQ R13, R14 ADOXQ 24(DI), R14 @@ -81,6 +85,11 @@ TEXT ·FeMulADX(SB),NOSPLIT,$0 ADCXQ R9, R10 ADOXQ 48(DI), R10 + // Consolidate both carry chains into r7 + ADCXQ AX, R11 + ADOXQ AX, R11 + MOVQ R11, 56(DI) + // Update accumulators MOVQ R14, 24(DI) MOVQ BX, 32(DI) @@ -90,11 +99,9 @@ TEXT ·FeMulADX(SB),NOSPLIT,$0 XORQ AX, AX // clear flags MOVQ 24(BP), DX // rdx <-- y3 - MULXQ 0(SI), R14, R15 // r3 <-- x0*y3 ADOXQ 24(DI), R14 MOVQ R14, 24(DI) - MULXQ 8(SI), BX, CX // r4 <-- x1*y3 ADCXQ R15, BX ADOXQ 32(DI), BX @@ -108,6 +115,11 @@ TEXT ·FeMulADX(SB),NOSPLIT,$0 ADCXQ R11, R12 ADOXQ 56(DI), R12 + // Consolidate both carry chains into r8 + ADCXQ AX, R13 + ADOXQ AX, R13 + MOVQ R13, 64(DI) + // Update accumulators MOVQ BX, 32(DI) MOVQ R8, 40(DI) @@ -117,23 +129,18 @@ TEXT ·FeMulADX(SB),NOSPLIT,$0 XORQ AX, AX // clear flags MOVQ 32(BP), DX // rdx <-- y4 - MULXQ 0(SI), BX, CX // r4 <-- x0*y4 ADOXQ 32(DI), BX MOVQ BX, 32(DI) - MULXQ 8(SI), R8, R9 // r5 <-- x1*y4 ADCXQ CX, R8 ADOXQ 40(DI), R8 - MULXQ 16(SI), R10, R11 // r6 <-- x2*y4 ADCXQ R9, R10 ADOXQ 48(DI), R10 - MULXQ 24(SI), R12, R13 // r7 <-- x3*y4 ADCXQ R11, R12 ADOXQ 56(DI), R12 - MULXQ 32(SI), R14, R15 // r8 <-- x4*y4 ADCXQ R13, R14 ADOXQ 64(DI), R14 @@ -141,11 +148,11 @@ TEXT ·FeMulADX(SB),NOSPLIT,$0 // Consolidate both carry chains in R15, our final output. ADCXQ AX, R15 ADOXQ AX, R15 + MOVQ R15, 72(DI) MOVQ R8, 40(DI) MOVQ R10, 48(DI) MOVQ R12, 56(DI) MOVQ R14, 64(DI) - MOVQ R15, 72(DI) RET diff --git a/internal/radix51/fe_test.go b/internal/radix51/fe_test.go index af86e74..17a2e99 100644 --- a/internal/radix51/fe_test.go +++ b/internal/radix51/fe_test.go @@ -195,9 +195,8 @@ func TestFeInvert(t *testing.T) { } func TestFeMulADX(t *testing.T) { - var x FieldElement = [5]uint64{1, 0, 0, 1, 16384} - // x[0] = 0x4e645be9215a2 - var y FieldElement = [5]uint64{16384, 16384, 16384, 16384, 16384} + var x FieldElement = [5]uint64{140194305, 1, 2955487002624, 1, 2955487002624} + var y FieldElement = [5]uint64{16384, 16384, 16384, 16384, 2955487002624} var z [10]uint64 FeMulADX(&z, &x, &y) diff --git a/internal/radix51/test.py b/internal/radix51/test.py new file mode 100644 index 0000000..26e5c6e --- /dev/null +++ b/internal/radix51/test.py @@ -0,0 +1,15 @@ +t = [2**(64*t) for t in range(10)] +p = 2**255 - 19 + +def field(e): + acc = 0 + for i in range(len(e)): + acc += e[i]*t[i] + return acc + +x = field([140194305, 1, 2955487002624, 1, 2955487002624]) +y = field([16384, 16384, 16384, 16384, 2955487002624]) + +z = field([2296943493120,2296943509504,48424995994501120,48424995994517504,8610922045896736768,96848353589002262,1217319147952685056,48425654538467760,1168896448901677056,473520]) + +print((x*y) % p == z % p)