radix51: use go1.12 intrinsics for 128-bit multiplications

This commit is contained in:
George Tankersley 2019-02-19 21:45:46 +00:00 committed by George Tankersley
parent 6b0cdf071a
commit 4ac00273fd
5 changed files with 70 additions and 69 deletions

View File

@ -34,39 +34,39 @@ func (v *FieldElement) Mul(x, y *FieldElement) *FieldElement {
x4_19 := x4 * 19
// calculate r0 = x0*y0 + 19*(x1*y4 + x2*y3 + x3*y2 + x4*y1)
r00, r01 := mul64x64(0, 0, x0, y0)
r00, r01 = mul64x64(r00, r01, x1_19, y4)
r00, r01 = mul64x64(r00, r01, x2_19, y3)
r00, r01 = mul64x64(r00, r01, x3_19, y2)
r00, r01 = mul64x64(r00, r01, x4_19, y1)
r00, r01 := madd64(0, 0, x0, y0)
r00, r01 = madd64(r00, r01, x1_19, y4)
r00, r01 = madd64(r00, r01, x2_19, y3)
r00, r01 = madd64(r00, r01, x3_19, y2)
r00, r01 = madd64(r00, r01, x4_19, y1)
// calculate r1 = x0*y1 + x1*y0 + 19*(x2*y4 + x3*y3 + x4*y2)
r10, r11 := mul64x64(0, 0, x0, y1)
r10, r11 = mul64x64(r10, r11, x1, y0)
r10, r11 = mul64x64(r10, r11, x2_19, y4)
r10, r11 = mul64x64(r10, r11, x3_19, y3)
r10, r11 = mul64x64(r10, r11, x4_19, y2)
r10, r11 := madd64(0, 0, x0, y1)
r10, r11 = madd64(r10, r11, x1, y0)
r10, r11 = madd64(r10, r11, x2_19, y4)
r10, r11 = madd64(r10, r11, x3_19, y3)
r10, r11 = madd64(r10, r11, x4_19, y2)
// calculate r2 = x0*y2 + x1*y1 + x2*y0 + 19*(x3*y4 + x4*y3)
r20, r21 := mul64x64(0, 0, x0, y2)
r20, r21 = mul64x64(r20, r21, x1, y1)
r20, r21 = mul64x64(r20, r21, x2, y0)
r20, r21 = mul64x64(r20, r21, x3_19, y4)
r20, r21 = mul64x64(r20, r21, x4_19, y3)
r20, r21 := madd64(0, 0, x0, y2)
r20, r21 = madd64(r20, r21, x1, y1)
r20, r21 = madd64(r20, r21, x2, y0)
r20, r21 = madd64(r20, r21, x3_19, y4)
r20, r21 = madd64(r20, r21, x4_19, y3)
// calculate r3 = x0*y3 + x1*y2 + x2*y1 + x3*y0 + 19*x4*y4
r30, r31 := mul64x64(0, 0, x0, y3)
r30, r31 = mul64x64(r30, r31, x1, y2)
r30, r31 = mul64x64(r30, r31, x2, y1)
r30, r31 = mul64x64(r30, r31, x3, y0)
r30, r31 = mul64x64(r30, r31, x4_19, y4)
r30, r31 := madd64(0, 0, x0, y3)
r30, r31 = madd64(r30, r31, x1, y2)
r30, r31 = madd64(r30, r31, x2, y1)
r30, r31 = madd64(r30, r31, x3, y0)
r30, r31 = madd64(r30, r31, x4_19, y4)
// calculate r4 = x0*y4 + x1*y3 + x2*y2 + x3*y1 + x4*y0
r40, r41 := mul64x64(0, 0, x0, y4)
r40, r41 = mul64x64(r40, r41, x1, y3)
r40, r41 = mul64x64(r40, r41, x2, y2)
r40, r41 = mul64x64(r40, r41, x3, y1)
r40, r41 = mul64x64(r40, r41, x4, y0)
r40, r41 := madd64(0, 0, x0, y4)
r40, r41 = madd64(r40, r41, x1, y3)
r40, r41 = madd64(r40, r41, x2, y2)
r40, r41 = madd64(r40, r41, x3, y1)
r40, r41 = madd64(r40, r41, x4, y0)
// After the multiplication we need to reduce (carry) the 5 coefficients to
// obtain a result with coefficients that are at most slightly larger than

View File

@ -31,29 +31,29 @@ func (v *FieldElement) Square(x *FieldElement) *FieldElement {
x4_19 := x4 * 19
// r0 = x0*x0 + x1*38*x4 + x2*38*x3
r00, r01 := mul64x64(0, 0, x0, x0)
r00, r01 = mul64x64(r00, r01, x1_38, x4)
r00, r01 = mul64x64(r00, r01, x2_38, x3)
r00, r01 := madd64(0, 0, x0, x0)
r00, r01 = madd64(r00, r01, x1_38, x4)
r00, r01 = madd64(r00, r01, x2_38, x3)
// r1 = x0*2*x1 + x2*38*x4 + x3*19*x3
r10, r11 := mul64x64(0, 0, x0_2, x1)
r10, r11 = mul64x64(r10, r11, x2_38, x4)
r10, r11 = mul64x64(r10, r11, x3_19, x3)
r10, r11 := madd64(0, 0, x0_2, x1)
r10, r11 = madd64(r10, r11, x2_38, x4)
r10, r11 = madd64(r10, r11, x3_19, x3)
// r2 = x0*2*x2 + x1*x1 + x3*38*x4
r20, r21 := mul64x64(0, 0, x0_2, x2)
r20, r21 = mul64x64(r20, r21, x1, x1)
r20, r21 = mul64x64(r20, r21, x3_38, x4)
r20, r21 := madd64(0, 0, x0_2, x2)
r20, r21 = madd64(r20, r21, x1, x1)
r20, r21 = madd64(r20, r21, x3_38, x4)
// r3 = x0*2*x3 + x1*2*x2 + x4*19*x4
r30, r31 := mul64x64(0, 0, x0_2, x3)
r30, r31 = mul64x64(r30, r31, x1_2, x2)
r30, r31 = mul64x64(r30, r31, x4_19, x4)
r30, r31 := madd64(0, 0, x0_2, x3)
r30, r31 = madd64(r30, r31, x1_2, x2)
r30, r31 = madd64(r30, r31, x4_19, x4)
// r4 = x0*2*x4 + x1*2*x3 + x2*x2
r40, r41 := mul64x64(0, 0, x0_2, x4)
r40, r41 = mul64x64(r40, r41, x1_2, x3)
r40, r41 = mul64x64(r40, r41, x2, x2)
r40, r41 := madd64(0, 0, x0_2, x4)
r40, r41 = madd64(r40, r41, x1_2, x3)
r40, r41 = madd64(r40, r41, x2, x2)
// Same reduction

View File

@ -9,60 +9,42 @@ import (
"crypto/rand"
"io"
"testing"
"unsafe"
)
func TestMul64to128(t *testing.T) {
a := uint64(5)
b := uint64(5)
r0, r1 := mul64x64(0, 0, a, b)
r0, r1 := madd64(0, 0, a, b)
if r0 != 0x19 || r1 != 0 {
t.Errorf("lo-range wide mult failed, got %d + %d*(2**64)", r0, r1)
}
a = uint64(18014398509481983) // 2^54 - 1
b = uint64(18014398509481983) // 2^54 - 1
r0, r1 = mul64x64(0, 0, a, b)
r0, r1 = madd64(0, 0, a, b)
if r0 != 0xff80000000000001 || r1 != 0xfffffffffff {
t.Errorf("hi-range wide mult failed, got %d + %d*(2**64)", r0, r1)
}
a = uint64(1125899906842661)
b = uint64(2097155)
r0, r1 = mul64x64(0, 0, a, b)
r0, r1 = mul64x64(r0, r1, a, b)
r0, r1 = mul64x64(r0, r1, a, b)
r0, r1 = mul64x64(r0, r1, a, b)
r0, r1 = mul64x64(r0, r1, a, b)
r0, r1 = madd64(0, 0, a, b)
r0, r1 = madd64(r0, r1, a, b)
r0, r1 = madd64(r0, r1, a, b)
r0, r1 = madd64(r0, r1, a, b)
r0, r1 = madd64(r0, r1, a, b)
if r0 != 16888498990613035 || r1 != 640 {
t.Errorf("wrong answer: %d + %d*(2**64)", r0, r1)
}
}
func BenchmarkWideMultInline(t *testing.B) {
var r0, r1, ol, oh uint64
a := uint64(18014398509481983) // 2^54 - 1
b := uint64(18014398509481983) // 2^54 - 1
for i := 0; i < t.N; i++ {
t1 := (a>>32)*(b&0xFFFFFFFF) + ((a & 0xFFFFFFFF) * (b & 0xFFFFFFFF) >> 32)
t2 := (a&0xFFFFFFFF)*(b>>32) + (t1 & 0xFFFFFFFF)
ol = (a * b) + r0
cmp := ol < r0
oh = r1 + (a>>32)*(b>>32) + t1>>32 + t2>>32 + uint64(*(*byte)(unsafe.Pointer(&cmp)))
r1 = oh
r0 = ol
}
}
func BenchmarkWideMultCall(t *testing.B) {
var r0, r1 uint64
a := uint64(18014398509481983)
b := uint64(18014398509481983)
for i := 0; i < t.N; i++ {
r0, r1 = mul64x64(r0, r1, a, b)
r0, r1 = madd64(r0, r1, a, b)
}
}

View File

@ -0,0 +1,18 @@
// Copyright (c) 2019 George Tankersley. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build go1.12
package radix51
import "math/bits"
// madd64 multiples two 64-bit numbers and adds them to a split 128-bit accumulator.
func madd64(lo, hi, a, b uint64) (ol uint64, oh uint64) {
oh, ol = bits.Mul64(a, b)
var c uint64
ol, c = bits.Add64(ol, lo, 0)
oh, _ = bits.Add64(oh, hi, c)
return
}

View File

@ -2,13 +2,14 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build !go1.12
package radix51
import "unsafe"
// mul64x64 multiples two 64-bit numbers and adds them to two accumulators.
// This function is written to ensure it inlines. I am so sorry.
func mul64x64(lo, hi, a, b uint64) (ol uint64, oh uint64) {
// madd64 multiples two 64-bit numbers and adds them to a split 128-bit accumulator.
func madd64(lo, hi, a, b uint64) (ol uint64, oh uint64) {
t1 := (a>>32)*(b&0xFFFFFFFF) + ((a & 0xFFFFFFFF) * (b & 0xFFFFFFFF) >> 32)
t2 := (a&0xFFFFFFFF)*(b>>32) + (t1 & 0xFFFFFFFF)
ol = (a * b) + lo