stock multiplication seems to work. TODO: fuzz, wire in reduction and 51-bit packing

This commit is contained in:
George Tankersley 2018-05-28 00:00:00 +00:00
parent a4358ce581
commit 9ca546367b
4 changed files with 40 additions and 17 deletions

View File

@ -377,6 +377,7 @@ var radix51A = field.FieldElement{
func BenchmarkFeMul51(b *testing.B) { func BenchmarkFeMul51(b *testing.B) {
var h field.FieldElement var h field.FieldElement
b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
field.FeMul(&h, &radix51A, &radix51A) field.FeMul(&h, &radix51A, &radix51A)
} }
@ -384,6 +385,7 @@ func BenchmarkFeMul51(b *testing.B) {
func BenchmarkFeMulADX(b *testing.B) { func BenchmarkFeMulADX(b *testing.B) {
var h [10]uint64 var h [10]uint64
b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
field.FeMulADX(&h, &radix51A, &radix51A) field.FeMulADX(&h, &radix51A, &radix51A)
} }

View File

@ -13,12 +13,11 @@ TEXT ·FeMulADX(SB),NOSPLIT,$0
MOVQ yp+16(FP), BP MOVQ yp+16(FP), BP
// The first diagonal sets up the accumulators. // The first diagonal sets up the accumulators.
XORQ AX,AX
MOVQ 0(BP), DX // rdx <-- y0 MOVQ 0(BP), DX // rdx <-- y0
MULXQ 0(SI), R8, R9 // r0 <-- x0*y0 MULXQ 0(SI), R8, R9 // r0 <-- x0*y0
MOVQ R8, 0(DI) MOVQ R8, 0(DI)
MULXQ 8(SI), R10, R11 // r1 <-- x1*y0 MULXQ 8(SI), R10, R11 // r1 <-- x1*y0
ADDQ R9, R10 ADDQ R9, R10
MULXQ 16(SI), R12, R13 // r2 <-- x2*y0 MULXQ 16(SI), R12, R13 // r2 <-- x2*y0
@ -28,6 +27,10 @@ TEXT ·FeMulADX(SB),NOSPLIT,$0
MULXQ 32(SI), BX, CX // r4 <-- x4*y0 MULXQ 32(SI), BX, CX // r4 <-- x4*y0
ADCQ R15, BX ADCQ R15, BX
// Consolidate carry chain into r5
ADCQ AX, CX
MOVQ CX, 40(DI)
MOVQ R10, 8(DI) MOVQ R10, 8(DI)
MOVQ R12, 16(DI) MOVQ R12, 16(DI)
MOVQ R14, 24(DI) MOVQ R14, 24(DI)
@ -36,11 +39,9 @@ TEXT ·FeMulADX(SB),NOSPLIT,$0
XORQ AX, AX // clear flags XORQ AX, AX // clear flags
MOVQ 8(BP), DX // rdx <-- y1 MOVQ 8(BP), DX // rdx <-- y1
MULXQ 0(SI), R10, R11 // r1 <-- x0*y1 MULXQ 0(SI), R10, R11 // r1 <-- x0*y1
ADOXQ 8(DI), R10 ADOXQ 8(DI), R10
MOVQ R10, 8(DI) MOVQ R10, 8(DI)
MULXQ 8(SI), R12, R13 // r2 <-- x1*y1 MULXQ 8(SI), R12, R13 // r2 <-- x1*y1
ADCXQ R11, R12 ADCXQ R11, R12
ADOXQ 16(DI), R12 ADOXQ 16(DI), R12
@ -54,6 +55,11 @@ TEXT ·FeMulADX(SB),NOSPLIT,$0
ADCXQ CX, R8 ADCXQ CX, R8
ADOXQ 40(DI), R8 ADOXQ 40(DI), R8
// Consolidate both carry chains into r6
ADCXQ AX, R9
ADOXQ AX, R9
MOVQ R9, 48(DI)
// Update accumulators // Update accumulators
MOVQ R12, 16(DI) MOVQ R12, 16(DI)
MOVQ R14, 24(DI) MOVQ R14, 24(DI)
@ -63,11 +69,9 @@ TEXT ·FeMulADX(SB),NOSPLIT,$0
XORQ AX, AX // clear flags XORQ AX, AX // clear flags
MOVQ 16(BP), DX // rdx <-- y2 MOVQ 16(BP), DX // rdx <-- y2
MULXQ 0(SI), R12, R13 // r2 <-- x0*y2 MULXQ 0(SI), R12, R13 // r2 <-- x0*y2
ADOXQ 16(DI), R12 ADOXQ 16(DI), R12
MOVQ R12, 16(DI) MOVQ R12, 16(DI)
MULXQ 8(SI), R14, R15 // r3 <-- x1*y2 MULXQ 8(SI), R14, R15 // r3 <-- x1*y2
ADCXQ R13, R14 ADCXQ R13, R14
ADOXQ 24(DI), R14 ADOXQ 24(DI), R14
@ -81,6 +85,11 @@ TEXT ·FeMulADX(SB),NOSPLIT,$0
ADCXQ R9, R10 ADCXQ R9, R10
ADOXQ 48(DI), R10 ADOXQ 48(DI), R10
// Consolidate both carry chains into r7
ADCXQ AX, R11
ADOXQ AX, R11
MOVQ R11, 56(DI)
// Update accumulators // Update accumulators
MOVQ R14, 24(DI) MOVQ R14, 24(DI)
MOVQ BX, 32(DI) MOVQ BX, 32(DI)
@ -90,11 +99,9 @@ TEXT ·FeMulADX(SB),NOSPLIT,$0
XORQ AX, AX // clear flags XORQ AX, AX // clear flags
MOVQ 24(BP), DX // rdx <-- y3 MOVQ 24(BP), DX // rdx <-- y3
MULXQ 0(SI), R14, R15 // r3 <-- x0*y3 MULXQ 0(SI), R14, R15 // r3 <-- x0*y3
ADOXQ 24(DI), R14 ADOXQ 24(DI), R14
MOVQ R14, 24(DI) MOVQ R14, 24(DI)
MULXQ 8(SI), BX, CX // r4 <-- x1*y3 MULXQ 8(SI), BX, CX // r4 <-- x1*y3
ADCXQ R15, BX ADCXQ R15, BX
ADOXQ 32(DI), BX ADOXQ 32(DI), BX
@ -108,6 +115,11 @@ TEXT ·FeMulADX(SB),NOSPLIT,$0
ADCXQ R11, R12 ADCXQ R11, R12
ADOXQ 56(DI), R12 ADOXQ 56(DI), R12
// Consolidate both carry chains into r8
ADCXQ AX, R13
ADOXQ AX, R13
MOVQ R13, 64(DI)
// Update accumulators // Update accumulators
MOVQ BX, 32(DI) MOVQ BX, 32(DI)
MOVQ R8, 40(DI) MOVQ R8, 40(DI)
@ -117,23 +129,18 @@ TEXT ·FeMulADX(SB),NOSPLIT,$0
XORQ AX, AX // clear flags XORQ AX, AX // clear flags
MOVQ 32(BP), DX // rdx <-- y4 MOVQ 32(BP), DX // rdx <-- y4
MULXQ 0(SI), BX, CX // r4 <-- x0*y4 MULXQ 0(SI), BX, CX // r4 <-- x0*y4
ADOXQ 32(DI), BX ADOXQ 32(DI), BX
MOVQ BX, 32(DI) MOVQ BX, 32(DI)
MULXQ 8(SI), R8, R9 // r5 <-- x1*y4 MULXQ 8(SI), R8, R9 // r5 <-- x1*y4
ADCXQ CX, R8 ADCXQ CX, R8
ADOXQ 40(DI), R8 ADOXQ 40(DI), R8
MULXQ 16(SI), R10, R11 // r6 <-- x2*y4 MULXQ 16(SI), R10, R11 // r6 <-- x2*y4
ADCXQ R9, R10 ADCXQ R9, R10
ADOXQ 48(DI), R10 ADOXQ 48(DI), R10
MULXQ 24(SI), R12, R13 // r7 <-- x3*y4 MULXQ 24(SI), R12, R13 // r7 <-- x3*y4
ADCXQ R11, R12 ADCXQ R11, R12
ADOXQ 56(DI), R12 ADOXQ 56(DI), R12
MULXQ 32(SI), R14, R15 // r8 <-- x4*y4 MULXQ 32(SI), R14, R15 // r8 <-- x4*y4
ADCXQ R13, R14 ADCXQ R13, R14
ADOXQ 64(DI), R14 ADOXQ 64(DI), R14
@ -141,11 +148,11 @@ TEXT ·FeMulADX(SB),NOSPLIT,$0
// Consolidate both carry chains in R15, our final output. // Consolidate both carry chains in R15, our final output.
ADCXQ AX, R15 ADCXQ AX, R15
ADOXQ AX, R15 ADOXQ AX, R15
MOVQ R15, 72(DI)
MOVQ R8, 40(DI) MOVQ R8, 40(DI)
MOVQ R10, 48(DI) MOVQ R10, 48(DI)
MOVQ R12, 56(DI) MOVQ R12, 56(DI)
MOVQ R14, 64(DI) MOVQ R14, 64(DI)
MOVQ R15, 72(DI)
RET RET

View File

@ -195,9 +195,8 @@ func TestFeInvert(t *testing.T) {
} }
func TestFeMulADX(t *testing.T) { func TestFeMulADX(t *testing.T) {
var x FieldElement = [5]uint64{1, 0, 0, 1, 16384} var x FieldElement = [5]uint64{140194305, 1, 2955487002624, 1, 2955487002624}
// x[0] = 0x4e645be9215a2 var y FieldElement = [5]uint64{16384, 16384, 16384, 16384, 2955487002624}
var y FieldElement = [5]uint64{16384, 16384, 16384, 16384, 16384}
var z [10]uint64 var z [10]uint64
FeMulADX(&z, &x, &y) FeMulADX(&z, &x, &y)

15
internal/radix51/test.py Normal file
View File

@ -0,0 +1,15 @@
t = [2**(64*t) for t in range(10)]
p = 2**255 - 19
def field(e):
acc = 0
for i in range(len(e)):
acc += e[i]*t[i]
return acc
x = field([140194305, 1, 2955487002624, 1, 2955487002624])
y = field([16384, 16384, 16384, 16384, 2955487002624])
z = field([2296943493120,2296943509504,48424995994501120,48424995994517504,8610922045896736768,96848353589002262,1217319147952685056,48425654538467760,1168896448901677056,473520])
print((x*y) % p == z % p)