internal/radix51: implement (*FieldElement).Mul32

This pure Go implementation of Mul32 is more than twice as fast as the
assembly Mul implementation, and four times faster than the pure Go Mul.

Mul32          7.91ns ± 1%
Mul            18.6ns ± 1%
Mul [purego]   33.4ns ± 0%

Before Go 1.13, where we can't use math/bits because the fallbacks might
not be constant time, Mul32 is a little slower, but its slowdown is not
nearly as large as that of the pure Go Mul.

Mul32          9.74ns ± 0%
Mul [purego]   75.4ns ± 1%
This commit is contained in:
Filippo Valsorda 2019-12-01 00:32:36 -04:00 committed by George Tankersley
parent 6db84dcfdc
commit a679c261e4
5 changed files with 89 additions and 3 deletions

View File

@ -29,3 +29,12 @@ func BenchmarkMul(b *testing.B) {
x.Mul(&x, &y)
}
}
// BenchmarkMul32 measures (*FieldElement).Mul32 by multiplying a field
// element, initialized via One, in place by a fixed 32-bit constant on every
// iteration.
func BenchmarkMul32(b *testing.B) {
	// Alternating bit pattern, equal to 0b10101010_10101010_10101010_10101010.
	const multiplier = 0xAAAAAAAA

	var acc radix51.FieldElement
	acc.One()

	// Keep the setup above out of the timed region.
	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		acc.Mul32(&acc, multiplier)
	}
}

View File

@ -373,3 +373,20 @@ func (v *FieldElement) Square(x *FieldElement) *FieldElement {
feSquare(v, x)
return v
}
// Mul32 sets v = x * y and returns v.
//
// Every limb of x is multiplied by y with mul51, which splits the product
// into a low part (masked to the low 51 bits) and a high part such that
// product = lo + hi * 2⁵¹. The high part of limb i is added into limb i+1,
// and the high part of the top limb wraps around into limb 0 scaled by 19.
func (v *FieldElement) Mul32(x *FieldElement, y uint32) *FieldElement {
	// All products are computed before anything is stored, so v and x may
	// point to the same element.
	lo0, hi0 := mul51(x[0], y)
	lo1, hi1 := mul51(x[1], y)
	lo2, hi2 := mul51(x[2], y)
	lo3, hi3 := mul51(x[3], y)
	lo4, hi4 := mul51(x[4], y)

	// No carry propagation pass is needed: each high part is only about 32
	// bits wide (plus whatever excess the input limb already carried), so
	// adding it to a 51-bit low part leaves the limbs only slightly oversized.
	v[4] = lo4 + hi3
	v[3] = lo3 + hi2
	v[2] = lo2 + hi1
	v[1] = lo1 + hi0
	v[0] = lo0 + 19*hi4 // top overflow folded back per the reduction identity
	return v
}

View File

@ -144,8 +144,9 @@ func TestMul64to128(t *testing.T) {
}
}
var r0, r1 uint64
func BenchmarkWideMultCall(t *testing.B) {
var r0, r1 uint64
a := uint64(18014398509481983)
b := uint64(18014398509481983)
@ -361,3 +362,34 @@ func TestSelectSwap(t *testing.T) {
t.Errorf("Swap failed")
}
}
// TestMul32 uses testing/quick to check that Mul32(x, y) agrees with Mul(x, Y),
// where Y is the field element whose lowest limb is y and whose other limbs
// are zero. It also checks that the Mul32 result has limbs of at most 52 bits
// and that the Mul result satisfies isInBounds.
func TestMul32(t *testing.T) {
	// limbsFitIn52Bits reports whether each of the five limbs of fe needs at
	// most 52 bits, i.e. one bit more than the 51-bit mask applied by mul51.
	limbsFitIn52Bits := func(fe *FieldElement) bool {
		for i := 0; i < 5; i++ {
			if bits.Len64(fe[i]) > 52 {
				return false
			}
		}
		return true
	}

	property := func(x FieldElement, y uint32) bool {
		// Both loops repeat the same call 100 times. Every iteration
		// recomputes from x, so the final value is that of a single call.
		viaMul32 := new(FieldElement)
		for n := 0; n < 100; n++ {
			viaMul32.Mul32(&x, y)
		}

		// Lift y into a field element: limb 0 holds y, the rest stay zero.
		yElem := new(FieldElement)
		yElem[0] = uint64(y)

		viaMul := new(FieldElement)
		for n := 0; n < 100; n++ {
			viaMul.Mul(&x, yElem)
		}

		if viaMul32.Equal(viaMul) != 1 {
			return false
		}
		if !limbsFitIn52Bits(viaMul32) {
			return false
		}
		return isInBounds(viaMul)
	}

	err := quick.Check(property, quickCheckConfig)
	if err != nil {
		t.Error(err)
	}
}

View File

@ -1,4 +1,5 @@
// Copyright (c) 2019 George Tankersley. All rights reserved.
// Copyright (c) 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
@ -8,7 +9,8 @@ package radix51
import "math/bits"
// madd64 multiplies two 64-bit numbers and adds them to a split 128-bit accumulator.
// madd64 returns ol + oh * 2⁶⁴ = lo + hi * 2⁶⁴ + a * b. That is, it multiplies
// a and b, and adds the result to the split uint128 [lo,hi].
func madd64(lo, hi, a, b uint64) (ol uint64, oh uint64) {
oh, ol = bits.Mul64(a, b)
var c uint64
@ -16,3 +18,11 @@ func madd64(lo, hi, a, b uint64) (ol uint64, oh uint64) {
oh, _ = bits.Add64(oh, hi, c)
return
}
// mul51 returns lo + hi * 2⁵¹ = a * b.
//
// lo is the low 51 bits of the product and hi is everything above them.
func mul51(a uint64, b uint32) (lo uint64, hi uint64) {
	// Full 128-bit product as two 64-bit words.
	upper, lower := bits.Mul64(a, uint64(b))

	lo = lower & maskLow51Bits
	// Re-split the product at bit 51 instead of bit 64: the top 13 bits of
	// the lower word become the bottom 13 bits of hi. Since b is at most 32
	// bits wide, upper is below 2³², so shifting it left by 13 loses nothing.
	hi = upper<<13 | lower>>51
	return lo, hi
}

View File

@ -1,4 +1,5 @@
// Copyright (c) 2017 George Tankersley. All rights reserved.
// Copyright (c) 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
@ -8,7 +9,8 @@ package radix51
import "unsafe"
// madd64 multiplies two 64-bit numbers and adds them to a split 128-bit accumulator.
// madd64 returns ol + oh * 2⁶⁴ = lo + hi * 2⁶⁴ + a * b. That is, it multiplies
// a and b, and adds the result to the split uint128 [lo,hi].
func madd64(lo, hi, a, b uint64) (ol uint64, oh uint64) {
t1 := (a>>32)*(b&0xFFFFFFFF) + ((a & 0xFFFFFFFF) * (b & 0xFFFFFFFF) >> 32)
t2 := (a&0xFFFFFFFF)*(b>>32) + (t1 & 0xFFFFFFFF)
@ -17,3 +19,19 @@ func madd64(lo, hi, a, b uint64) (ol uint64, oh uint64) {
oh = hi + (a>>32)*(b>>32) + t1>>32 + t2>>32 + uint64(*(*byte)(unsafe.Pointer(&cmp)))
return
}
const mask32 = 1<<32 - 1

// mul51 returns lo + hi * 2⁵¹ = a * b.
//
// This variant does not use math/bits; the upper word of the product is
// assembled from 64-bit multiplications on the 32-bit halves of a.
func mul51(a uint64, b uint32) (lo uint64, hi uint64) {
	m := uint64(b)

	// Upper 64 bits of a*m. With a = aHi*2³² + aLo, the product is
	// aHi*m*2³² + aLo*m, so the upper word is (aHi*m + (aLo*m)>>32) >> 32.
	// Neither intermediate can overflow 64 bits because m is below 2³².
	lowPartial := (a & mask32) * m
	upper := ((a>>32)*m + lowPartial>>32) >> 32

	// Lower 64 bits of a*m: ordinary wrapping multiplication.
	lower := a * m

	// Re-split the 128-bit product at bit 51 instead of bit 64.
	lo = lower & maskLow51Bits
	hi = upper<<13 | lower>>51
	return lo, hi
}