mirror of https://github.com/gtank/ristretto255
radix51: use go1.12 intrinsics for 128-bit multiplications
This commit is contained in:
parent
6b0cdf071a
commit
4ac00273fd
|
@ -34,39 +34,39 @@ func (v *FieldElement) Mul(x, y *FieldElement) *FieldElement {
|
|||
x4_19 := x4 * 19
|
||||
|
||||
// calculate r0 = x0*y0 + 19*(x1*y4 + x2*y3 + x3*y2 + x4*y1)
|
||||
r00, r01 := mul64x64(0, 0, x0, y0)
|
||||
r00, r01 = mul64x64(r00, r01, x1_19, y4)
|
||||
r00, r01 = mul64x64(r00, r01, x2_19, y3)
|
||||
r00, r01 = mul64x64(r00, r01, x3_19, y2)
|
||||
r00, r01 = mul64x64(r00, r01, x4_19, y1)
|
||||
r00, r01 := madd64(0, 0, x0, y0)
|
||||
r00, r01 = madd64(r00, r01, x1_19, y4)
|
||||
r00, r01 = madd64(r00, r01, x2_19, y3)
|
||||
r00, r01 = madd64(r00, r01, x3_19, y2)
|
||||
r00, r01 = madd64(r00, r01, x4_19, y1)
|
||||
|
||||
// calculate r1 = x0*y1 + x1*y0 + 19*(x2*y4 + x3*y3 + x4*y2)
|
||||
r10, r11 := mul64x64(0, 0, x0, y1)
|
||||
r10, r11 = mul64x64(r10, r11, x1, y0)
|
||||
r10, r11 = mul64x64(r10, r11, x2_19, y4)
|
||||
r10, r11 = mul64x64(r10, r11, x3_19, y3)
|
||||
r10, r11 = mul64x64(r10, r11, x4_19, y2)
|
||||
r10, r11 := madd64(0, 0, x0, y1)
|
||||
r10, r11 = madd64(r10, r11, x1, y0)
|
||||
r10, r11 = madd64(r10, r11, x2_19, y4)
|
||||
r10, r11 = madd64(r10, r11, x3_19, y3)
|
||||
r10, r11 = madd64(r10, r11, x4_19, y2)
|
||||
|
||||
// calculate r2 = x0*y2 + x1*y1 + x2*y0 + 19*(x3*y4 + x4*y3)
|
||||
r20, r21 := mul64x64(0, 0, x0, y2)
|
||||
r20, r21 = mul64x64(r20, r21, x1, y1)
|
||||
r20, r21 = mul64x64(r20, r21, x2, y0)
|
||||
r20, r21 = mul64x64(r20, r21, x3_19, y4)
|
||||
r20, r21 = mul64x64(r20, r21, x4_19, y3)
|
||||
r20, r21 := madd64(0, 0, x0, y2)
|
||||
r20, r21 = madd64(r20, r21, x1, y1)
|
||||
r20, r21 = madd64(r20, r21, x2, y0)
|
||||
r20, r21 = madd64(r20, r21, x3_19, y4)
|
||||
r20, r21 = madd64(r20, r21, x4_19, y3)
|
||||
|
||||
// calculate r3 = x0*y3 + x1*y2 + x2*y1 + x3*y0 + 19*x4*y4
|
||||
r30, r31 := mul64x64(0, 0, x0, y3)
|
||||
r30, r31 = mul64x64(r30, r31, x1, y2)
|
||||
r30, r31 = mul64x64(r30, r31, x2, y1)
|
||||
r30, r31 = mul64x64(r30, r31, x3, y0)
|
||||
r30, r31 = mul64x64(r30, r31, x4_19, y4)
|
||||
r30, r31 := madd64(0, 0, x0, y3)
|
||||
r30, r31 = madd64(r30, r31, x1, y2)
|
||||
r30, r31 = madd64(r30, r31, x2, y1)
|
||||
r30, r31 = madd64(r30, r31, x3, y0)
|
||||
r30, r31 = madd64(r30, r31, x4_19, y4)
|
||||
|
||||
// calculate r4 = x0*y4 + x1*y3 + x2*y2 + x3*y1 + x4*y0
|
||||
r40, r41 := mul64x64(0, 0, x0, y4)
|
||||
r40, r41 = mul64x64(r40, r41, x1, y3)
|
||||
r40, r41 = mul64x64(r40, r41, x2, y2)
|
||||
r40, r41 = mul64x64(r40, r41, x3, y1)
|
||||
r40, r41 = mul64x64(r40, r41, x4, y0)
|
||||
r40, r41 := madd64(0, 0, x0, y4)
|
||||
r40, r41 = madd64(r40, r41, x1, y3)
|
||||
r40, r41 = madd64(r40, r41, x2, y2)
|
||||
r40, r41 = madd64(r40, r41, x3, y1)
|
||||
r40, r41 = madd64(r40, r41, x4, y0)
|
||||
|
||||
// After the multiplication we need to reduce (carry) the 5 coefficients to
|
||||
// obtain a result with coefficients that are at most slightly larger than
|
||||
|
|
|
@ -31,29 +31,29 @@ func (v *FieldElement) Square(x *FieldElement) *FieldElement {
|
|||
x4_19 := x4 * 19
|
||||
|
||||
// r0 = x0*x0 + x1*38*x4 + x2*38*x3
|
||||
r00, r01 := mul64x64(0, 0, x0, x0)
|
||||
r00, r01 = mul64x64(r00, r01, x1_38, x4)
|
||||
r00, r01 = mul64x64(r00, r01, x2_38, x3)
|
||||
r00, r01 := madd64(0, 0, x0, x0)
|
||||
r00, r01 = madd64(r00, r01, x1_38, x4)
|
||||
r00, r01 = madd64(r00, r01, x2_38, x3)
|
||||
|
||||
// r1 = x0*2*x1 + x2*38*x4 + x3*19*x3
|
||||
r10, r11 := mul64x64(0, 0, x0_2, x1)
|
||||
r10, r11 = mul64x64(r10, r11, x2_38, x4)
|
||||
r10, r11 = mul64x64(r10, r11, x3_19, x3)
|
||||
r10, r11 := madd64(0, 0, x0_2, x1)
|
||||
r10, r11 = madd64(r10, r11, x2_38, x4)
|
||||
r10, r11 = madd64(r10, r11, x3_19, x3)
|
||||
|
||||
// r2 = x0*2*x2 + x1*x1 + x3*38*x4
|
||||
r20, r21 := mul64x64(0, 0, x0_2, x2)
|
||||
r20, r21 = mul64x64(r20, r21, x1, x1)
|
||||
r20, r21 = mul64x64(r20, r21, x3_38, x4)
|
||||
r20, r21 := madd64(0, 0, x0_2, x2)
|
||||
r20, r21 = madd64(r20, r21, x1, x1)
|
||||
r20, r21 = madd64(r20, r21, x3_38, x4)
|
||||
|
||||
// r3 = x0*2*x3 + x1*2*x2 + x4*19*x4
|
||||
r30, r31 := mul64x64(0, 0, x0_2, x3)
|
||||
r30, r31 = mul64x64(r30, r31, x1_2, x2)
|
||||
r30, r31 = mul64x64(r30, r31, x4_19, x4)
|
||||
r30, r31 := madd64(0, 0, x0_2, x3)
|
||||
r30, r31 = madd64(r30, r31, x1_2, x2)
|
||||
r30, r31 = madd64(r30, r31, x4_19, x4)
|
||||
|
||||
// r4 = x0*2*x4 + x1*2*x3 + x2*x2
|
||||
r40, r41 := mul64x64(0, 0, x0_2, x4)
|
||||
r40, r41 = mul64x64(r40, r41, x1_2, x3)
|
||||
r40, r41 = mul64x64(r40, r41, x2, x2)
|
||||
r40, r41 := madd64(0, 0, x0_2, x4)
|
||||
r40, r41 = madd64(r40, r41, x1_2, x3)
|
||||
r40, r41 = madd64(r40, r41, x2, x2)
|
||||
|
||||
// Same reduction
|
||||
|
||||
|
|
|
@ -9,60 +9,42 @@ import (
|
|||
"crypto/rand"
|
||||
"io"
|
||||
"testing"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
func TestMul64to128(t *testing.T) {
|
||||
a := uint64(5)
|
||||
b := uint64(5)
|
||||
r0, r1 := mul64x64(0, 0, a, b)
|
||||
r0, r1 := madd64(0, 0, a, b)
|
||||
if r0 != 0x19 || r1 != 0 {
|
||||
t.Errorf("lo-range wide mult failed, got %d + %d*(2**64)", r0, r1)
|
||||
}
|
||||
|
||||
a = uint64(18014398509481983) // 2^54 - 1
|
||||
b = uint64(18014398509481983) // 2^54 - 1
|
||||
r0, r1 = mul64x64(0, 0, a, b)
|
||||
r0, r1 = madd64(0, 0, a, b)
|
||||
if r0 != 0xff80000000000001 || r1 != 0xfffffffffff {
|
||||
t.Errorf("hi-range wide mult failed, got %d + %d*(2**64)", r0, r1)
|
||||
}
|
||||
|
||||
a = uint64(1125899906842661)
|
||||
b = uint64(2097155)
|
||||
r0, r1 = mul64x64(0, 0, a, b)
|
||||
r0, r1 = mul64x64(r0, r1, a, b)
|
||||
r0, r1 = mul64x64(r0, r1, a, b)
|
||||
r0, r1 = mul64x64(r0, r1, a, b)
|
||||
r0, r1 = mul64x64(r0, r1, a, b)
|
||||
r0, r1 = madd64(0, 0, a, b)
|
||||
r0, r1 = madd64(r0, r1, a, b)
|
||||
r0, r1 = madd64(r0, r1, a, b)
|
||||
r0, r1 = madd64(r0, r1, a, b)
|
||||
r0, r1 = madd64(r0, r1, a, b)
|
||||
if r0 != 16888498990613035 || r1 != 640 {
|
||||
t.Errorf("wrong answer: %d + %d*(2**64)", r0, r1)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkWideMultInline(t *testing.B) {
|
||||
var r0, r1, ol, oh uint64
|
||||
a := uint64(18014398509481983) // 2^54 - 1
|
||||
b := uint64(18014398509481983) // 2^54 - 1
|
||||
|
||||
for i := 0; i < t.N; i++ {
|
||||
t1 := (a>>32)*(b&0xFFFFFFFF) + ((a & 0xFFFFFFFF) * (b & 0xFFFFFFFF) >> 32)
|
||||
t2 := (a&0xFFFFFFFF)*(b>>32) + (t1 & 0xFFFFFFFF)
|
||||
ol = (a * b) + r0
|
||||
cmp := ol < r0
|
||||
oh = r1 + (a>>32)*(b>>32) + t1>>32 + t2>>32 + uint64(*(*byte)(unsafe.Pointer(&cmp)))
|
||||
|
||||
r1 = oh
|
||||
r0 = ol
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkWideMultCall(t *testing.B) {
|
||||
var r0, r1 uint64
|
||||
a := uint64(18014398509481983)
|
||||
b := uint64(18014398509481983)
|
||||
|
||||
for i := 0; i < t.N; i++ {
|
||||
r0, r1 = mul64x64(r0, r1, a, b)
|
||||
r0, r1 = madd64(r0, r1, a, b)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,18 @@
|
|||
// Copyright (c) 2019 George Tankersley. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build go1.12
|
||||
|
||||
package radix51
|
||||
|
||||
import "math/bits"
|
||||
|
||||
// madd64 multiples two 64-bit numbers and adds them to a split 128-bit accumulator.
|
||||
func madd64(lo, hi, a, b uint64) (ol uint64, oh uint64) {
|
||||
oh, ol = bits.Mul64(a, b)
|
||||
var c uint64
|
||||
ol, c = bits.Add64(ol, lo, 0)
|
||||
oh, _ = bits.Add64(oh, hi, c)
|
||||
return
|
||||
}
|
|
@ -2,13 +2,14 @@
|
|||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build !go1.12
|
||||
|
||||
package radix51
|
||||
|
||||
import "unsafe"
|
||||
|
||||
// mul64x64 multiples two 64-bit numbers and adds them to two accumulators.
|
||||
// This function is written to ensure it inlines. I am so sorry.
|
||||
func mul64x64(lo, hi, a, b uint64) (ol uint64, oh uint64) {
|
||||
// madd64 multiples two 64-bit numbers and adds them to a split 128-bit accumulator.
|
||||
func madd64(lo, hi, a, b uint64) (ol uint64, oh uint64) {
|
||||
t1 := (a>>32)*(b&0xFFFFFFFF) + ((a & 0xFFFFFFFF) * (b & 0xFFFFFFFF) >> 32)
|
||||
t2 := (a&0xFFFFFFFF)*(b>>32) + (t1 & 0xFFFFFFFF)
|
||||
ol = (a * b) + lo
|
Loading…
Reference in New Issue