From dbb69fe1627a547ec909250a8d69e3a074dd24dd Mon Sep 17 00:00:00 2001 From: Yawning Angel Date: Tue, 9 Feb 2021 14:46:57 +0000 Subject: [PATCH] internal/radix51: Use BMI2 to accelerate FieldElement.Square MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ``` name old time/op new time/op delta Square-8 29.0ns ± 2% 26.5ns ± 0% -8.63% (p=0.008 n=5+5) ``` --- internal/radix51/fe_amd64.go | 6 ++- internal/radix51/fe_amd64.s | 84 +++++++++++++++++++++++++++++++++++- 2 files changed, 87 insertions(+), 3 deletions(-) diff --git a/internal/radix51/fe_amd64.go b/internal/radix51/fe_amd64.go index dbe7fc7..b23e6ba 100644 --- a/internal/radix51/fe_amd64.go +++ b/internal/radix51/fe_amd64.go @@ -22,12 +22,16 @@ var useBMI2 bool func feMulAmd64(out, a, b *FieldElement, useBMI2 bool) //go:noescape -func feSquare(out, x *FieldElement) +func feSquareAmd64(out, x *FieldElement, useBMI2 bool) func feMul(out, a, b *FieldElement) { feMulAmd64(out, a, b, useBMI2) } +func feSquare(out, x *FieldElement) { + feSquareAmd64(out, x, useBMI2) +} + func init() { useBMI2 = cpu.Initialized && cpu.X86.HasBMI2 } diff --git a/internal/radix51/fe_amd64.s b/internal/radix51/fe_amd64.s index 7d01270..293fb04 100644 --- a/internal/radix51/fe_amd64.s +++ b/internal/radix51/fe_amd64.s @@ -327,10 +327,89 @@ mul_reduce: MOVQ R14, 32(DI) RET -// func feSquare(out, x *FieldElement) -TEXT ·feSquare(SB), NOSPLIT|NOFRAME, $0-16 +// func feSquareAmd64(out, x *FieldElement, useBMI2 bool) +TEXT ·feSquareAmd64(SB), NOSPLIT|NOFRAME, $0-17 MOVQ x+8(FP), BX + // Pick the appropriate implementation, based on whether the caller + // reports that BMI2 is supported. + MOVBQZX useBMI2+16(FP), DX + TESTQ DX, DX + JZ square_vanilla + + // This codepath uses BMI2 to shave off a number of instructions, + // for a slight performance gain. 
+ + // r0 = a0*a0 + // r1 = 2*a0*a1 + // r2 = 2*a0*a2 + // r3 = 2*a0*a3 + // r4 = 2*a0*a4 + MOVQ 0(BX), DX // rdx <- a0 + MULXQ DX, SI, BP // r00, r01 <- a0*a0 + SHLQ $1, DX // rdx *= 2 (d0 = 2*a0) + MULXQ 8(BX), R8, R9 // r10, r11 <- d0*a1 + MULXQ 16(BX), R10, R11 // r20, r21 <- d0*a2 + MULXQ 24(BX), R12, R13 // r30, r31 <- d0*a3 + MULXQ 32(BX), R14, R15 // r40, r41 <- d0*a4 + + // r2 += a1*a1 + // r3 += 2*a1*a2 + // r4 += 2*a1*a3 + MOVQ 8(BX), DX // rdx <- a1 + MULXQ DX, AX, DI // rdi, rax = a1*a1 + ADDQ AX, R10 // r20 += rax + ADCQ DI, R11 // r21 += rdi + cf + SHLQ $1, DX // rdx *= 2 (d1 = 2*a1) + MULXQ 16(BX), AX, DI // rdi, rax = d1*a2 + ADDQ AX, R12 // r30 += rax + ADCQ DI, R13 // r31 += rdi + cf + MULXQ 24(BX), AX, DI // rdi, rax = d1*a3 + ADDQ AX, R14 // r40 += rax + ADCQ DI, R15 // r41 += rdi + cf + IMUL3Q $19, DX, DX // rdx *= 19 (d1_38 = 2*19*a1) + MULXQ 32(BX), AX, DI // rdi, rax = d1_38*a4 + ADDQ AX, SI // r00 += rax + ADCQ DI, BP // r01 += rdi + cf + + // r4 += a2*a2 + // r0 += 2*19*a2*a3 + // r1 += 2*19*a2*a4 + MOVQ 16(BX), DX // rdx <- a2 + MULXQ DX, AX, DI // rdi, rax = a2*a2 + ADDQ AX, R14 // r40 += rax + ADCQ DI, R15 // r41 += rdi + cf + IMUL3Q $38, DX, DX // rdx *= 2*19 (d2_38 = 2*19*a2) + MULXQ 24(BX), AX, DI // rdi, rax = d2_38*a3 + ADDQ AX, SI // r00 += rax + ADCQ DI, BP // r01 += rdi + cf + MULXQ 32(BX), AX, DI // rdi, rax = d2_38*a4 + ADDQ AX, R8 // r10 += rax + ADCQ DI, R9 // r11 += rdi + cf + + // r1 += 19*a3*a3 + // r2 += 2*19*a3*a4 + MOVQ 24(BX), DX // rdx <- a3 + IMUL3Q $19, DX, DX // rdx *= 19 (d3_19 = 19*a3) + MULXQ 24(BX), AX, DI // rdi, rax = d3_19*a3 + ADDQ AX, R8 // r10 += rax + ADCQ DI, R9 // r11 += rdi + cf + SHLQ $1, DX // rdx *= 2 (d3_38 = 2*19*a3) + MULXQ 32(BX), AX, DI // rdi, rax = d3_38*a4 + ADDQ AX, R10 // r20 += rax + ADCQ DI, R11 // r21 += rdi + cf + + // r3 += 19*a4*a4 + MOVQ 32(BX), DX // rdx <- a4 + IMUL3Q $19, DX, DX // rdx *= 19 (d4_19 = 19*a4) + MULXQ 32(BX), AX, DI // rdi, rax = d4_19*a4 + ADDQ AX, R12 // r30 += 
rax + ADCQ DI, R13 // r31 += rdi + cf + + JMP square_reduce + +square_vanilla: + // r0 = x0*x0 + x1*38*x4 + x2*38*x3 MOVQ 0(BX), AX MULQ 0(BX) @@ -423,6 +502,7 @@ TEXT ·feSquare(SB), NOSPLIT|NOFRAME, $0-16 ADDQ AX, R14 ADCQ DX, R15 +square_reduce: // Reduce reduce64()