diff --git a/internal/radix51/fe.go b/internal/radix51/fe.go index 755007c..1ee3bc6 100644 --- a/internal/radix51/fe.go +++ b/internal/radix51/fe.go @@ -215,7 +215,10 @@ func (v *FieldElement) Set(a *FieldElement) *FieldElement { return v } -func (v *FieldElement) FromBytes(x *[32]byte) *FieldElement { +func (v *FieldElement) FromBytes(x []byte) *FieldElement { + if len(x) != 32 { + panic("invalid input size") + } v[0] = uint64(x[0]) v[0] |= uint64(x[1]) << 8 v[0] |= uint64(x[2]) << 16 @@ -260,7 +263,10 @@ func (v *FieldElement) FromBytes(x *[32]byte) *FieldElement { return v } -func (v *FieldElement) ToBytes(r *[32]byte) { +func (v *FieldElement) ToBytes(r []byte) { + if len(r) != 32 { + panic("invalid input size") + } t := new(FieldElement).Reduce(v) r[0] = byte(t[0] & 0xff) @@ -323,12 +329,12 @@ func (v *FieldElement) FromBig(num *big.Int) *FieldElement { } } - return v.FromBytes(&buf) + return v.FromBytes(buf[:]) } func (v *FieldElement) ToBig() *big.Int { var buf [32]byte - v.ToBytes(&buf) // does a reduction + v.ToBytes(buf[:]) // does a reduction numWords := 256 / bits.UintSize words := make([]big.Word, numWords) @@ -354,8 +360,8 @@ func (v *FieldElement) ToBig() *big.Int { // Equal returns 1 if v and u are equal, and 0 otherwise. func (v *FieldElement) Equal(u *FieldElement) int { var sa, sv [32]byte - u.ToBytes(&sa) - v.ToBytes(&sv) + u.ToBytes(sa[:]) + v.ToBytes(sv[:]) return subtle.ConstantTimeCompare(sa[:], sv[:]) } @@ -381,7 +387,7 @@ func (v *FieldElement) CondNeg(u *FieldElement, cond int) *FieldElement { // IsNegative returns 1 if v is negative, and 0 otherwise. func (v *FieldElement) IsNegative() int { var b [32]byte - v.ToBytes(&b) + v.ToBytes(b[:]) return int(b[0] & 1) } diff --git a/internal/radix51/fe_test.go b/internal/radix51/fe_test.go index b77a50d..1cbbfb7 100644 --- a/internal/radix51/fe_test.go +++ b/internal/radix51/fe_test.go @@ -123,8 +123,8 @@ func TestFeFromBytesRoundTrip(t *testing.T) { in = [32]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32} - fe.FromBytes(&in) - fe.ToBytes(&out) + fe.FromBytes(in[:]) + fe.ToBytes(out[:]) if !bytes.Equal(in[:], out[:]) { t.Error("Bytes<>FE doesn't roundtrip") @@ -137,8 +137,8 @@ func TestFeFromBytesRoundTrip(t *testing.T) { fe[3] = 0x5e8fca9e0881c fe[4] = 0x5c490f087d796 - fe.ToBytes(&out) - r.FromBytes(&out) + fe.ToBytes(out[:]) + r.FromBytes(out[:]) for i := 0; i < len(fe); i++ { if r[i] != fe[i] { @@ -173,7 +173,7 @@ func TestSanity(t *testing.T) { if err != nil { t.Fatal(err) } - x.FromBytes(&bytes) + x.FromBytes(bytes[:]) x2.Mul(&x, &x) // FeMulGo(&x2Go, &x, &x) @@ -232,7 +232,7 @@ func TestFeInvert(t *testing.T) { if err != nil { t.Fatal(err) } - x.FromBytes(&bytes) + x.FromBytes(bytes[:]) xinv.Invert(&x) r.Mul(&x, &xinv)