internal/scalar: add scalar inversion

This commit is contained in:
Henry de Valence 2019-05-15 11:33:43 -07:00 committed by Filippo Valsorda
parent 2d09ffd636
commit 24394cb387
2 changed files with 126 additions and 1 deletions

View File

@ -968,7 +968,7 @@ func (s *Scalar) SignedRadix16() [64]int8 {
// Compute unsigned radix-16 digits:
for i := 0; i < 32; i++ {
digits[2*i] = int8(s[i] & 15)
digits[2*i+1] = int8((s[i]>>4) & 15)
digits[2*i+1] = int8((s[i] >> 4) & 15)
}
// Recenter coefficients:
@ -980,3 +980,114 @@ func (s *Scalar) SignedRadix16() [64]int8 {
return digits
}
// Given k > 0, set s = s**(2*i).
func (s *Scalar) pow2k(k int) {
for i := 0; i < k; i++ {
s.Mul(s, s)
}
}
// Inv sets s to the inverse of a nonzero scalar v and returns s.
func (s *Scalar) Inv(t *Scalar) *Scalar {
// Uses a hardcoded sliding window of width 4.
var table [8]Scalar
var tt Scalar
tt.Mul(t, t)
table[0] = *t
for i := 0; i < 7; i++ {
table[i+1].Mul(&table[i], &tt)
}
// Now table = [t**1, t**3, t**7, t**11, t**13, t**15]
// so t**k = t[k/2] for odd k
// To compute the sliding window digits, use the following Sage script:
// sage: import itertools
// sage: def sliding_window(w,k):
// ....: digits = []
// ....: while k > 0:
// ....: if k % 2 == 1:
// ....: kmod = k % (2**w)
// ....: digits.append(kmod)
// ....: k = k - kmod
// ....: else:
// ....: digits.append(0)
// ....: k = k // 2
// ....: return digits
// Now we can compute s roughly as follows:
// sage: s = 1
// sage: for coeff in reversed(sliding_window(4,l-2)):
// ....: s = s*s
// ....: if coeff > 0 :
// ....: s = s*t**coeff
// This works on one bit at a time, with many runs of zeros.
// The digits can be collapsed into [(count, coeff)] as follows:
// sage: [(len(list(group)),d) for d,group in itertools.groupby(sliding_window(4,l-2))]
// Entries of the form (k, 0) turn into pow2k(k)
// Entries of the form (1, coeff) turn into a squaring and then a table lookup.
// We can fold the squaring into the previous pow2k(k) as pow2k(k+1).
*s = table[1/2]
s.pow2k(127 + 1)
s.Mul(s, &table[1/2])
s.pow2k(4 + 1)
s.Mul(s, &table[9/2])
s.pow2k(3 + 1)
s.Mul(s, &table[11/2])
s.pow2k(3 + 1)
s.Mul(s, &table[13/2])
s.pow2k(3 + 1)
s.Mul(s, &table[15/2])
s.pow2k(4 + 1)
s.Mul(s, &table[7/2])
s.pow2k(4 + 1)
s.Mul(s, &table[15/2])
s.pow2k(3 + 1)
s.Mul(s, &table[5/2])
s.pow2k(3 + 1)
s.Mul(s, &table[1/2])
s.pow2k(4 + 1)
s.Mul(s, &table[15/2])
s.pow2k(4 + 1)
s.Mul(s, &table[15/2])
s.pow2k(4 + 1)
s.Mul(s, &table[7/2])
s.pow2k(3 + 1)
s.Mul(s, &table[3/2])
s.pow2k(4 + 1)
s.Mul(s, &table[11/2])
s.pow2k(5 + 1)
s.Mul(s, &table[11/2])
s.pow2k(9 + 1)
s.Mul(s, &table[9/2])
s.pow2k(3 + 1)
s.Mul(s, &table[3/2])
s.pow2k(4 + 1)
s.Mul(s, &table[3/2])
s.pow2k(4 + 1)
s.Mul(s, &table[3/2])
s.pow2k(4 + 1)
s.Mul(s, &table[9/2])
s.pow2k(3 + 1)
s.Mul(s, &table[7/2])
s.pow2k(3 + 1)
s.Mul(s, &table[3/2])
s.pow2k(3 + 1)
s.Mul(s, &table[13/2])
s.pow2k(3 + 1)
s.Mul(s, &table[7/2])
s.pow2k(4 + 1)
s.Mul(s, &table[9/2])
s.pow2k(3 + 1)
s.Mul(s, &table[15/2])
s.pow2k(4 + 1)
s.Mul(s, &table[11/2])
return s
}

View File

@ -82,3 +82,17 @@ func TestNonAdjacentForm(t *testing.T) {
}
}
}
func TestInvert(t *testing.T) {
invertWorks := func(x Scalar) bool {
var xInv, check Scalar
xInv.Inv(&x)
check.Mul(&x, &xInv)
return check.Equal(&scOne) == 1
}
if err := quick.Check(invertWorks, quickCheckConfig); err != nil {
t.Error(err)
}
}