From a679c261e4df42455c830458dee32512a3e08211 Mon Sep 17 00:00:00 2001
From: Filippo Valsorda
Date: Sun, 1 Dec 2019 00:32:36 -0400
Subject: [PATCH] internal/radix51: implement (*FieldElement).Mul32
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This pure Go implementation of Mul32 is more than twice as fast as the
assembly Mul implementation, and four times faster than the pure Go Mul.

Mul32           7.91ns ± 1%
Mul             18.6ns ± 1%
Mul [purego]    33.4ns ± 0%

Before Go 1.13, where we can't use math/bits because the fallbacks might
not be constant time, Mul32 is a little slower, but not nearly as much
as the pure Go Mul.

Mul32           9.74ns ± 0%
Mul [purego]    75.4ns ± 1%
---
 internal/radix51/bench_test.go |  9 +++++++++
 internal/radix51/fe.go         | 17 +++++++++++++++++
 internal/radix51/fe_test.go    | 34 +++++++++++++++++++++++++++++++++-
 internal/radix51/mul_bits.go   | 12 +++++++++++-
 internal/radix51/mul_compat.go | 20 +++++++++++++++++++-
 5 files changed, 89 insertions(+), 3 deletions(-)

diff --git a/internal/radix51/bench_test.go b/internal/radix51/bench_test.go
index ce12323..07722ad 100644
--- a/internal/radix51/bench_test.go
+++ b/internal/radix51/bench_test.go
@@ -29,3 +29,12 @@ func BenchmarkMul(b *testing.B) {
 		x.Mul(&x, &y)
 	}
 }
+
+func BenchmarkMul32(b *testing.B) {
+	var x radix51.FieldElement
+	x.One()
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		x.Mul32(&x, 0b10101010_10101010_10101010_10101010) // y = 0xAAAAAAAA, result fed back into x; NOTE(review): 0b literals and _ separators need Go 1.13+ — confirm the minimum toolchain
+	}
+}
diff --git a/internal/radix51/fe.go b/internal/radix51/fe.go
index 09f36bc..090fca6 100644
--- a/internal/radix51/fe.go
+++ b/internal/radix51/fe.go
@@ -373,3 +373,20 @@ func (v *FieldElement) Square(x *FieldElement) *FieldElement {
 	feSquare(v, x)
 	return v
 }
+
+// Mul32 sets v = x * y and returns v.
+func (v *FieldElement) Mul32(x *FieldElement, y uint32) *FieldElement {
+	x0lo, x0hi := mul51(x[0], y) // each limb product is split at bit 51: lo + hi * 2⁵¹ = x[i] * y
+	x1lo, x1hi := mul51(x[1], y)
+	x2lo, x2hi := mul51(x[2], y)
+	x3lo, x3hi := mul51(x[3], y)
+	x4lo, x4hi := mul51(x[4], y)
+	v[0] = x0lo + 19*x4hi // top carry has weight 2²⁵⁵ (5 × 51 bits) and wraps to limb 0 times 19, per the reduction identity
+	v[1] = x1lo + x0hi    // every other hi part is added to the next limb up
+	v[2] = x2lo + x1hi
+	v[3] = x3lo + x2hi
+	v[4] = x4lo + x3hi
+	// The hi portions are going to be only 32 bits, plus any previous excess,
+	// so we can skip the carry propagation.
+	return v
+}
diff --git a/internal/radix51/fe_test.go b/internal/radix51/fe_test.go
index 2ffcb9d..5dbdedf 100644
--- a/internal/radix51/fe_test.go
+++ b/internal/radix51/fe_test.go
@@ -144,8 +144,9 @@ func TestMul64to128(t *testing.T) {
 	}
 }
 
+var r0, r1 uint64
+
 func BenchmarkWideMultCall(t *testing.B) {
-	var r0, r1 uint64
 	a := uint64(18014398509481983)
 	b := uint64(18014398509481983)
 
@@ -361,3 +362,34 @@ func TestSelectSwap(t *testing.T) {
 		t.Errorf("Swap failed")
 	}
 }
+
+func TestMul32(t *testing.T) {
+	isAlmostInBounds := func(x *FieldElement) bool { // true when every limb fits in 52 bits
+		return bits.Len64(x[0]) <= 52 &&
+			bits.Len64(x[1]) <= 52 &&
+			bits.Len64(x[2]) <= 52 &&
+			bits.Len64(x[3]) <= 52 &&
+			bits.Len64(x[4]) <= 52
+	}
+
+	mul32EquivalentToMul := func(x FieldElement, y uint32) bool {
+		t1 := new(FieldElement)
+		for i := 0; i < 100; i++ {
+			t1.Mul32(&x, y) // NOTE(review): inputs never change, so each pass recomputes the same value — confirm chaining was not intended
+		}
+
+		ty := new(FieldElement)
+		ty[0] = uint64(y) // y as a field element: limb 0 holds y, the other limbs stay zero
+
+		t2 := new(FieldElement)
+		for i := 0; i < 100; i++ {
+			t2.Mul(&x, ty)
+		}
+
+		return t1.Equal(t2) == 1 && isAlmostInBounds(t1) && isInBounds(t2)
+	}
+
+	if err := quick.Check(mul32EquivalentToMul, quickCheckConfig); err != nil {
+		t.Error(err)
+	}
+}
diff --git a/internal/radix51/mul_bits.go b/internal/radix51/mul_bits.go
index bf10590..89b867c 100644
--- a/internal/radix51/mul_bits.go
+++ b/internal/radix51/mul_bits.go
@@ -1,4 +1,5 @@
 // Copyright (c) 2019 George Tankersley. All rights reserved.
+// Copyright (c) 2019 The Go Authors. All rights reserved.
 // Use of this source code is governed by a BSD-style
 // license that can be found in the LICENSE file.
 
@@ -8,7 +9,8 @@ package radix51
 
 import "math/bits"
 
-// madd64 multiples two 64-bit numbers and adds them to a split 128-bit accumulator.
+// madd64 returns ol + oh * 2⁶⁴ = lo + hi * 2⁶⁴ + a * b. That is, it multiplies
+// a and b, and adds the result to the split uint128 [lo,hi].
 func madd64(lo, hi, a, b uint64) (ol uint64, oh uint64) {
 	oh, ol = bits.Mul64(a, b)
 	var c uint64
@@ -16,3 +18,11 @@ func madd64(lo, hi, a, b uint64) (ol uint64, oh uint64) {
 	oh, _ = bits.Add64(oh, hi, c)
 	return
 }
+
+// mul51 returns lo + hi * 2⁵¹ = a * b.
+func mul51(a uint64, b uint32) (lo uint64, hi uint64) {
+	mh, ml := bits.Mul64(a, uint64(b)) // full 128-bit product: mh is the high word, ml the low word
+	lo = ml & maskLow51Bits
+	hi = (mh << 13) | (ml >> 51) // bits 51 and up of the product; 13 = 64 - 51
+	return
+}
diff --git a/internal/radix51/mul_compat.go b/internal/radix51/mul_compat.go
index 59d5b75..4b9ac35 100644
--- a/internal/radix51/mul_compat.go
+++ b/internal/radix51/mul_compat.go
@@ -1,4 +1,5 @@
 // Copyright (c) 2017 George Tankersley. All rights reserved.
+// Copyright (c) 2019 The Go Authors. All rights reserved.
 // Use of this source code is governed by a BSD-style
 // license that can be found in the LICENSE file.
 
@@ -8,7 +9,8 @@ package radix51
 
 import "unsafe"
 
-// madd64 multiplies two 64-bit numbers and adds them to a split 128-bit accumulator.
+// madd64 returns ol + oh * 2⁶⁴ = lo + hi * 2⁶⁴ + a * b. That is, it multiplies
+// a and b, and adds the result to the split uint128 [lo,hi].
 func madd64(lo, hi, a, b uint64) (ol uint64, oh uint64) {
 	t1 := (a>>32)*(b&0xFFFFFFFF) + ((a & 0xFFFFFFFF) * (b & 0xFFFFFFFF) >> 32)
 	t2 := (a&0xFFFFFFFF)*(b>>32) + (t1 & 0xFFFFFFFF)
@@ -17,3 +19,19 @@ func madd64(lo, hi, a, b uint64) (ol uint64, oh uint64) {
 	oh = hi + (a>>32)*(b>>32) + t1>>32 + t2>>32 + uint64(*(*byte)(unsafe.Pointer(&cmp)))
 	return
 }
+
+const mask32 = 1<<32 - 1 // the low 32 bits set
+
+// mul51 returns lo + hi * 2⁵¹ = a * b.
+func mul51(a uint64, b uint32) (lo uint64, hi uint64) {
+	w0 := (a & mask32) * uint64(b)  // low 32 bits of a, times b
+	t := (a>>32)*uint64(b) + w0>>32 // high 32 bits of a, times b, plus the carry out of w0
+	w1 := t & mask32
+	w2 := t >> 32
+	mh := w2 + w1>>32   // high 64 bits of a*b; w1 is masked to 32 bits, so w1>>32 adds zero
+	ml := a * uint64(b) // low 64 bits of a*b (the multiplication wraps modulo 2⁶⁴)
+
+	lo = ml & maskLow51Bits      // keeps the bits selected by maskLow51Bits (declared outside this hunk; presumably the low 51)
+	hi = (mh << 13) | (ml >> 51) // bits 51 and up of the product; 13 = 64 - 51
+	return
+}