diff --git a/internal/radix51/fe_mul.go b/internal/radix51/fe_mul.go
index 08ffbb4..27e7f5b 100644
--- a/internal/radix51/fe_mul.go
+++ b/internal/radix51/fe_mul.go
@@ -34,39 +34,39 @@ func (v *FieldElement) Mul(x, y *FieldElement) *FieldElement {
 	x4_19 := x4 * 19
 
 	// calculate r0 = x0*y0 + 19*(x1*y4 + x2*y3 + x3*y2 + x4*y1)
-	r00, r01 := mul64x64(0, 0, x0, y0)
-	r00, r01 = mul64x64(r00, r01, x1_19, y4)
-	r00, r01 = mul64x64(r00, r01, x2_19, y3)
-	r00, r01 = mul64x64(r00, r01, x3_19, y2)
-	r00, r01 = mul64x64(r00, r01, x4_19, y1)
+	r00, r01 := madd64(0, 0, x0, y0)
+	r00, r01 = madd64(r00, r01, x1_19, y4)
+	r00, r01 = madd64(r00, r01, x2_19, y3)
+	r00, r01 = madd64(r00, r01, x3_19, y2)
+	r00, r01 = madd64(r00, r01, x4_19, y1)
 
 	// calculate r1 = x0*y1 + x1*y0 + 19*(x2*y4 + x3*y3 + x4*y2)
-	r10, r11 := mul64x64(0, 0, x0, y1)
-	r10, r11 = mul64x64(r10, r11, x1, y0)
-	r10, r11 = mul64x64(r10, r11, x2_19, y4)
-	r10, r11 = mul64x64(r10, r11, x3_19, y3)
-	r10, r11 = mul64x64(r10, r11, x4_19, y2)
+	r10, r11 := madd64(0, 0, x0, y1)
+	r10, r11 = madd64(r10, r11, x1, y0)
+	r10, r11 = madd64(r10, r11, x2_19, y4)
+	r10, r11 = madd64(r10, r11, x3_19, y3)
+	r10, r11 = madd64(r10, r11, x4_19, y2)
 
 	// calculate r2 = x0*y2 + x1*y1 + x2*y0 + 19*(x3*y4 + x4*y3)
-	r20, r21 := mul64x64(0, 0, x0, y2)
-	r20, r21 = mul64x64(r20, r21, x1, y1)
-	r20, r21 = mul64x64(r20, r21, x2, y0)
-	r20, r21 = mul64x64(r20, r21, x3_19, y4)
-	r20, r21 = mul64x64(r20, r21, x4_19, y3)
+	r20, r21 := madd64(0, 0, x0, y2)
+	r20, r21 = madd64(r20, r21, x1, y1)
+	r20, r21 = madd64(r20, r21, x2, y0)
+	r20, r21 = madd64(r20, r21, x3_19, y4)
+	r20, r21 = madd64(r20, r21, x4_19, y3)
 
 	// calculate r3 = x0*y3 + x1*y2 + x2*y1 + x3*y0 + 19*x4*y4
-	r30, r31 := mul64x64(0, 0, x0, y3)
-	r30, r31 = mul64x64(r30, r31, x1, y2)
-	r30, r31 = mul64x64(r30, r31, x2, y1)
-	r30, r31 = mul64x64(r30, r31, x3, y0)
-	r30, r31 = mul64x64(r30, r31, x4_19, y4)
+	r30, r31 := madd64(0, 0, x0, y3)
+	r30, r31 = madd64(r30, r31, x1, y2)
+	r30, r31 = madd64(r30, r31, x2, y1)
+	r30, r31 = madd64(r30, r31, x3, y0)
+	r30, r31 = madd64(r30, r31, x4_19, y4)
 
 	// calculate r4 = x0*y4 + x1*y3 + x2*y2 + x3*y1 + x4*y0
-	r40, r41 := mul64x64(0, 0, x0, y4)
-	r40, r41 = mul64x64(r40, r41, x1, y3)
-	r40, r41 = mul64x64(r40, r41, x2, y2)
-	r40, r41 = mul64x64(r40, r41, x3, y1)
-	r40, r41 = mul64x64(r40, r41, x4, y0)
+	r40, r41 := madd64(0, 0, x0, y4)
+	r40, r41 = madd64(r40, r41, x1, y3)
+	r40, r41 = madd64(r40, r41, x2, y2)
+	r40, r41 = madd64(r40, r41, x3, y1)
+	r40, r41 = madd64(r40, r41, x4, y0)
 
 	// After the multiplication we need to reduce (carry) the 5 coefficients to
 	// obtain a result with coefficients that are at most slightly larger than
diff --git a/internal/radix51/fe_square.go b/internal/radix51/fe_square.go
index 0a2d4ac..46e01fb 100644
--- a/internal/radix51/fe_square.go
+++ b/internal/radix51/fe_square.go
@@ -31,29 +31,29 @@ func (v *FieldElement) Square(x *FieldElement) *FieldElement {
 	x4_19 := x4 * 19
 
 	// r0 = x0*x0 + x1*38*x4 + x2*38*x3
-	r00, r01 := mul64x64(0, 0, x0, x0)
-	r00, r01 = mul64x64(r00, r01, x1_38, x4)
-	r00, r01 = mul64x64(r00, r01, x2_38, x3)
+	r00, r01 := madd64(0, 0, x0, x0)
+	r00, r01 = madd64(r00, r01, x1_38, x4)
+	r00, r01 = madd64(r00, r01, x2_38, x3)
 
 	// r1 = x0*2*x1 + x2*38*x4 + x3*19*x3
-	r10, r11 := mul64x64(0, 0, x0_2, x1)
-	r10, r11 = mul64x64(r10, r11, x2_38, x4)
-	r10, r11 = mul64x64(r10, r11, x3_19, x3)
+	r10, r11 := madd64(0, 0, x0_2, x1)
+	r10, r11 = madd64(r10, r11, x2_38, x4)
+	r10, r11 = madd64(r10, r11, x3_19, x3)
 
 	// r2 = x0*2*x2 + x1*x1 + x3*38*x4
-	r20, r21 := mul64x64(0, 0, x0_2, x2)
-	r20, r21 = mul64x64(r20, r21, x1, x1)
-	r20, r21 = mul64x64(r20, r21, x3_38, x4)
+	r20, r21 := madd64(0, 0, x0_2, x2)
+	r20, r21 = madd64(r20, r21, x1, x1)
+	r20, r21 = madd64(r20, r21, x3_38, x4)
 
 	// r3 = x0*2*x3 + x1*2*x2 + x4*19*x4
-	r30, r31 := mul64x64(0, 0, x0_2, x3)
-	r30, r31 = mul64x64(r30, r31, x1_2, x2)
-	r30, r31 = mul64x64(r30, r31, x4_19, x4)
+	r30, r31 := madd64(0, 0, x0_2, x3)
+	r30, r31 = madd64(r30, r31, x1_2, x2)
+	r30, r31 = madd64(r30, r31, x4_19, x4)
 
 	// r4 = x0*2*x4 + x1*2*x3 + x2*x2
-	r40, r41 := mul64x64(0, 0, x0_2, x4)
-	r40, r41 = mul64x64(r40, r41, x1_2, x3)
-	r40, r41 = mul64x64(r40, r41, x2, x2)
+	r40, r41 := madd64(0, 0, x0_2, x4)
+	r40, r41 = madd64(r40, r41, x1_2, x3)
+	r40, r41 = madd64(r40, r41, x2, x2)
 
 	// Same reduction
 
diff --git a/internal/radix51/fe_test.go b/internal/radix51/fe_test.go
index d7343c2..2ce6629 100644
--- a/internal/radix51/fe_test.go
+++ b/internal/radix51/fe_test.go
@@ -9,60 +9,42 @@ import (
 	"crypto/rand"
 	"io"
 	"testing"
-	"unsafe"
 )
 
 func TestMul64to128(t *testing.T) {
 	a := uint64(5)
 	b := uint64(5)
-	r0, r1 := mul64x64(0, 0, a, b)
+	r0, r1 := madd64(0, 0, a, b)
 	if r0 != 0x19 || r1 != 0 {
 		t.Errorf("lo-range wide mult failed, got %d + %d*(2**64)", r0, r1)
 	}
 
 	a = uint64(18014398509481983) // 2^54 - 1
 	b = uint64(18014398509481983) // 2^54 - 1
-	r0, r1 = mul64x64(0, 0, a, b)
+	r0, r1 = madd64(0, 0, a, b)
 	if r0 != 0xff80000000000001 || r1 != 0xfffffffffff {
 		t.Errorf("hi-range wide mult failed, got %d + %d*(2**64)", r0, r1)
 	}
 
 	a = uint64(1125899906842661)
 	b = uint64(2097155)
-	r0, r1 = mul64x64(0, 0, a, b)
-	r0, r1 = mul64x64(r0, r1, a, b)
-	r0, r1 = mul64x64(r0, r1, a, b)
-	r0, r1 = mul64x64(r0, r1, a, b)
-	r0, r1 = mul64x64(r0, r1, a, b)
+	r0, r1 = madd64(0, 0, a, b)
+	r0, r1 = madd64(r0, r1, a, b)
+	r0, r1 = madd64(r0, r1, a, b)
+	r0, r1 = madd64(r0, r1, a, b)
+	r0, r1 = madd64(r0, r1, a, b)
 	if r0 != 16888498990613035 || r1 != 640 {
 		t.Errorf("wrong answer: %d + %d*(2**64)", r0, r1)
 	}
 }
 
-func BenchmarkWideMultInline(t *testing.B) {
-	var r0, r1, ol, oh uint64
-	a := uint64(18014398509481983) // 2^54 - 1
-	b := uint64(18014398509481983) // 2^54 - 1
-
-	for i := 0; i < t.N; i++ {
-		t1 := (a>>32)*(b&0xFFFFFFFF) + ((a & 0xFFFFFFFF) * (b & 0xFFFFFFFF) >> 32)
-		t2 := (a&0xFFFFFFFF)*(b>>32) + (t1 & 0xFFFFFFFF)
-		ol = (a * b) + r0
-		cmp := ol < r0
-		oh = r1 + (a>>32)*(b>>32) + t1>>32 + t2>>32 + uint64(*(*byte)(unsafe.Pointer(&cmp)))
-
-		r1 = oh
-		r0 = ol
-	}
-}
-
 func BenchmarkWideMultCall(t *testing.B) {
 	var r0, r1 uint64
 	a := uint64(18014398509481983)
 	b := uint64(18014398509481983)
 
 	for i := 0; i < t.N; i++ {
-		r0, r1 = mul64x64(r0, r1, a, b)
+		r0, r1 = madd64(r0, r1, a, b)
 	}
 }
 
diff --git a/internal/radix51/mul_bits.go b/internal/radix51/mul_bits.go
new file mode 100644
index 0000000..0965d03
--- /dev/null
+++ b/internal/radix51/mul_bits.go
@@ -0,0 +1,18 @@
+// Copyright (c) 2019 George Tankersley. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build go1.12
+
+package radix51
+
+import "math/bits"
+
+// madd64 multiplies two 64-bit numbers and adds them to a split 128-bit accumulator.
+func madd64(lo, hi, a, b uint64) (ol uint64, oh uint64) {
+	oh, ol = bits.Mul64(a, b)
+	var c uint64
+	ol, c = bits.Add64(ol, lo, 0)
+	oh, _ = bits.Add64(oh, hi, c)
+	return
+}
diff --git a/internal/radix51/mul.go b/internal/radix51/mul_compat.go
similarity index 69%
rename from internal/radix51/mul.go
rename to internal/radix51/mul_compat.go
index c196d9d..54427a9 100644
--- a/internal/radix51/mul.go
+++ b/internal/radix51/mul_compat.go
@@ -2,13 +2,14 @@
 // Use of this source code is governed by a BSD-style
 // license that can be found in the LICENSE file.
 
+// +build !go1.12
+
 package radix51
 
 import "unsafe"
 
-// mul64x64 multiples two 64-bit numbers and adds them to two accumulators.
-// This function is written to ensure it inlines. I am so sorry.
-func mul64x64(lo, hi, a, b uint64) (ol uint64, oh uint64) {
+// madd64 multiplies two 64-bit numbers and adds them to a split 128-bit accumulator.
+func madd64(lo, hi, a, b uint64) (ol uint64, oh uint64) {
 	t1 := (a>>32)*(b&0xFFFFFFFF) + ((a & 0xFFFFFFFF) * (b & 0xFFFFFFFF) >> 32)
 	t2 := (a&0xFFFFFFFF)*(b>>32) + (t1 & 0xFFFFFFFF)
 	ol = (a * b) + lo
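
Note (illustrative, not part of the patch): madd64 computes lo + hi*2^64 + a*b and returns the result as two 64-bit words. On go1.12 and later the build uses bits.Mul64 and bits.Add64, which the compiler lowers to a widening multiply plus add-with-carry on 64-bit targets, replacing the hand-rolled 32x32 schoolbook version. The standalone sketch below mirrors that accumulator pattern and cross-checks it against math/big; the package main wrapper, sample inputs, and big.Int comparison are illustrative assumptions, not repository code.

package main

import (
	"fmt"
	"math/big"
	"math/bits"
)

// madd64 returns lo + hi*2^64 + a*b as a (low, high) pair of 64-bit words,
// mirroring the math/bits-based implementation in mul_bits.go.
func madd64(lo, hi, a, b uint64) (ol uint64, oh uint64) {
	oh, ol = bits.Mul64(a, b)     // full 128-bit product of a and b
	var c uint64
	ol, c = bits.Add64(ol, lo, 0) // add the low accumulator word, keep the carry
	oh, _ = bits.Add64(oh, hi, c) // fold the carry into the high word
	return
}

func main() {
	a := uint64(1125899906842661)
	b := uint64(2097155)

	// Accumulate a*b five times, as TestMul64to128 does.
	var lo, hi uint64
	want := new(big.Int)
	ab := new(big.Int).Mul(new(big.Int).SetUint64(a), new(big.Int).SetUint64(b))
	for i := 0; i < 5; i++ {
		lo, hi = madd64(lo, hi, a, b)
		want.Add(want, ab)
	}

	// Reassemble the split accumulator and compare with the exact sum.
	got := new(big.Int).Lsh(new(big.Int).SetUint64(hi), 64)
	got.Add(got, new(big.Int).SetUint64(lo))
	fmt.Println(lo, hi, got.Cmp(want) == 0) // expect: 16888498990613035 640 true
}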