Stock multiplication seems to work. TODO: fuzz it, then wire in reduction and 51-bit packing.

This commit is contained in:
George Tankersley 2018-05-28 00:00:00 +00:00
parent a4358ce581
commit 9ca546367b
4 changed files with 40 additions and 17 deletions

View File

@ -377,6 +377,7 @@ var radix51A = field.FieldElement{
func BenchmarkFeMul51(b *testing.B) {
var h field.FieldElement
b.ResetTimer()
for i := 0; i < b.N; i++ {
field.FeMul(&h, &radix51A, &radix51A)
}
@ -384,6 +385,7 @@ func BenchmarkFeMul51(b *testing.B) {
func BenchmarkFeMulADX(b *testing.B) {
var h [10]uint64
b.ResetTimer()
for i := 0; i < b.N; i++ {
field.FeMulADX(&h, &radix51A, &radix51A)
}

View File

@ -13,12 +13,11 @@ TEXT ·FeMulADX(SB),NOSPLIT,$0
MOVQ yp+16(FP), BP
// The first diagonal sets up the accumulators.
XORQ AX,AX
MOVQ 0(BP), DX // rdx <-- y0
MULXQ 0(SI), R8, R9 // r0 <-- x0*y0
MOVQ R8, 0(DI)
MULXQ 8(SI), R10, R11 // r1 <-- x1*y0
ADDQ R9, R10
MULXQ 16(SI), R12, R13 // r2 <-- x2*y0
@ -28,6 +27,10 @@ TEXT ·FeMulADX(SB),NOSPLIT,$0
MULXQ 32(SI), BX, CX // r4 <-- x4*y0
ADCQ R15, BX
// Consolidate carry chain into r5
ADCQ AX, CX
MOVQ CX, 40(DI)
MOVQ R10, 8(DI)
MOVQ R12, 16(DI)
MOVQ R14, 24(DI)
@ -36,11 +39,9 @@ TEXT ·FeMulADX(SB),NOSPLIT,$0
XORQ AX, AX // clear flags
MOVQ 8(BP), DX // rdx <-- y1
MULXQ 0(SI), R10, R11 // r1 <-- x0*y1
ADOXQ 8(DI), R10
MOVQ R10, 8(DI)
MULXQ 8(SI), R12, R13 // r2 <-- x1*y1
ADCXQ R11, R12
ADOXQ 16(DI), R12
@ -54,6 +55,11 @@ TEXT ·FeMulADX(SB),NOSPLIT,$0
ADCXQ CX, R8
ADOXQ 40(DI), R8
// Consolidate both carry chains into r6
ADCXQ AX, R9
ADOXQ AX, R9
MOVQ R9, 48(DI)
// Update accumulators
MOVQ R12, 16(DI)
MOVQ R14, 24(DI)
@ -63,11 +69,9 @@ TEXT ·FeMulADX(SB),NOSPLIT,$0
XORQ AX, AX // clear flags
MOVQ 16(BP), DX // rdx <-- y2
MULXQ 0(SI), R12, R13 // r2 <-- x0*y2
ADOXQ 16(DI), R12
MOVQ R12, 16(DI)
MULXQ 8(SI), R14, R15 // r3 <-- x1*y2
ADCXQ R13, R14
ADOXQ 24(DI), R14
@ -81,6 +85,11 @@ TEXT ·FeMulADX(SB),NOSPLIT,$0
ADCXQ R9, R10
ADOXQ 48(DI), R10
// Consolidate both carry chains into r7
ADCXQ AX, R11
ADOXQ AX, R11
MOVQ R11, 56(DI)
// Update accumulators
MOVQ R14, 24(DI)
MOVQ BX, 32(DI)
@ -90,11 +99,9 @@ TEXT ·FeMulADX(SB),NOSPLIT,$0
XORQ AX, AX // clear flags
MOVQ 24(BP), DX // rdx <-- y3
MULXQ 0(SI), R14, R15 // r3 <-- x0*y3
ADOXQ 24(DI), R14
MOVQ R14, 24(DI)
MULXQ 8(SI), BX, CX // r4 <-- x1*y3
ADCXQ R15, BX
ADOXQ 32(DI), BX
@ -108,6 +115,11 @@ TEXT ·FeMulADX(SB),NOSPLIT,$0
ADCXQ R11, R12
ADOXQ 56(DI), R12
// Consolidate both carry chains into r8
ADCXQ AX, R13
ADOXQ AX, R13
MOVQ R13, 64(DI)
// Update accumulators
MOVQ BX, 32(DI)
MOVQ R8, 40(DI)
@ -117,23 +129,18 @@ TEXT ·FeMulADX(SB),NOSPLIT,$0
XORQ AX, AX // clear flags
MOVQ 32(BP), DX // rdx <-- y4
MULXQ 0(SI), BX, CX // r4 <-- x0*y4
ADOXQ 32(DI), BX
MOVQ BX, 32(DI)
MULXQ 8(SI), R8, R9 // r5 <-- x1*y4
ADCXQ CX, R8
ADOXQ 40(DI), R8
MULXQ 16(SI), R10, R11 // r6 <-- x2*y4
ADCXQ R9, R10
ADOXQ 48(DI), R10
MULXQ 24(SI), R12, R13 // r7 <-- x3*y4
ADCXQ R11, R12
ADOXQ 56(DI), R12
MULXQ 32(SI), R14, R15 // r8 <-- x4*y4
ADCXQ R13, R14
ADOXQ 64(DI), R14
@ -141,11 +148,11 @@ TEXT ·FeMulADX(SB),NOSPLIT,$0
// Consolidate both carry chains in R15, our final output.
ADCXQ AX, R15
ADOXQ AX, R15
MOVQ R15, 72(DI)
MOVQ R8, 40(DI)
MOVQ R10, 48(DI)
MOVQ R12, 56(DI)
MOVQ R14, 64(DI)
MOVQ R15, 72(DI)
RET

View File

@ -195,9 +195,8 @@ func TestFeInvert(t *testing.T) {
}
func TestFeMulADX(t *testing.T) {
var x FieldElement = [5]uint64{1, 0, 0, 1, 16384}
// x[0] = 0x4e645be9215a2
var y FieldElement = [5]uint64{16384, 16384, 16384, 16384, 16384}
var x FieldElement = [5]uint64{140194305, 1, 2955487002624, 1, 2955487002624}
var y FieldElement = [5]uint64{16384, 16384, 16384, 16384, 2955487002624}
var z [10]uint64
FeMulADX(&z, &x, &y)

15
internal/radix51/test.py Normal file
View File

@ -0,0 +1,15 @@
# Sanity check for the FeMulADX test vectors: interpret arrays of 64-bit limbs
# (least-significant limb first) as integers and compare the product of the
# two inputs with the recorded output, modulo p = 2^255 - 19.

# t[i] is the positional weight of limb i, i.e. 2^(64*i), for ten limbs.
t = [2**(64*i) for i in range(10)]
p = 2**255 - 19


def field(e):
    """Return the integer encoded by the 64-bit limbs in e.

    Limb i contributes e[i] * 2**(64*i), so e[0] is the least-significant
    limb. Any number of limbs is accepted; an empty sequence yields 0.
    """
    acc = 0
    for i, limb in enumerate(e):
        # Compute the weight directly rather than indexing t, so inputs with
        # more limbs than len(t) do not raise IndexError.
        acc += limb * 2**(64*i)
    return acc


# Inputs match the x and y vectors used by TestFeMulADX.
x = field([140194305, 1, 2955487002624, 1, 2955487002624])
y = field([16384, 16384, 16384, 16384, 2955487002624])
# Ten-limb product; presumably the output observed from FeMulADX - confirm
# against the Go test before relying on it.
z = field([2296943493120,2296943509504,48424995994501120,48424995994517504,8610922045896736768,96848353589002262,1217319147952685056,48425654538467760,1168896448901677056,473520])
# Prints True when the recorded product agrees with x*y modulo p.
print((x*y) % p == z % p)