diff --git a/internal/radix51/fe.go b/internal/radix51/fe.go index c07b9b3..43515e2 100644 --- a/internal/radix51/fe.go +++ b/internal/radix51/fe.go @@ -263,10 +263,10 @@ func (v *FieldElement) FromBytes(x []byte) *FieldElement { return v } -func (v *FieldElement) ToBytes(r []byte) { - if len(r) != 32 { - panic("invalid input size") - } +// AppendBytes appends a 32 bytes little-endian encoding of v to b. +func (v *FieldElement) AppendBytes(b []byte) []byte { + res, r := sliceForAppend(b, 32) + t := new(FieldElement).Reduce(v) r[0] = byte(t[0] & 0xff) @@ -309,6 +309,22 @@ func (v *FieldElement) ToBytes(r []byte) { r[29] = byte((t[4] >> 28) & 0xff) r[30] = byte((t[4] >> 36) & 0xff) r[31] = byte((t[4] >> 44)) + + return res +} + +// sliceForAppend extends the input slice by n bytes. head is the full extended +// slice, while tail is the appended part. If the original slice has sufficient +// capacity no allocation is performed. +func sliceForAppend(in []byte, n int) (head, tail []byte) { + if total := len(in) + n; cap(in) >= total { + head = in[:total] + } else { + head = make([]byte, total) + copy(head, in) + } + tail = head[len(in):] + return } func (v *FieldElement) FromBig(num *big.Int) *FieldElement { @@ -333,8 +349,7 @@ func (v *FieldElement) FromBig(num *big.Int) *FieldElement { } func (v *FieldElement) ToBig() *big.Int { - var buf [32]byte - v.ToBytes(buf[:]) // does a reduction + buf := v.AppendBytes(nil) numWords := 256 / bits.UintSize words := make([]big.Word, numWords) @@ -360,8 +375,8 @@ func (v *FieldElement) ToBig() *big.Int { // Equal returns 1 if v and u are equal, and 0 otherwise. func (v *FieldElement) Equal(u *FieldElement) int { var sa, sv [32]byte - u.ToBytes(sa[:]) - v.ToBytes(sv[:]) + u.AppendBytes(sa[:0]) + v.AppendBytes(sv[:0]) return subtle.ConstantTimeCompare(sa[:], sv[:]) } @@ -388,7 +403,7 @@ func (v *FieldElement) CondNeg(u *FieldElement, cond int) *FieldElement { // IsNegative returns 1 if v is negative, and 0 otherwise. func (v *FieldElement) IsNegative() int { var b [32]byte - v.ToBytes(b[:]) + v.AppendBytes(b[:0]) return int(b[0] & 1) } diff --git a/internal/radix51/fe_test.go b/internal/radix51/fe_test.go index 1cbbfb7..2366573 100644 --- a/internal/radix51/fe_test.go +++ b/internal/radix51/fe_test.go @@ -124,7 +124,7 @@ func TestFeFromBytesRoundTrip(t *testing.T) { 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32} fe.FromBytes(in[:]) - fe.ToBytes(out[:]) + fe.AppendBytes(out[:0]) if !bytes.Equal(in[:], out[:]) { t.Error("Bytes<>FE doesn't roundtrip") @@ -137,7 +137,7 @@ func TestFeFromBytesRoundTrip(t *testing.T) { fe[3] = 0x5e8fca9e0881c fe[4] = 0x5c490f087d796 - fe.ToBytes(out[:]) + fe.AppendBytes(out[:0]) r.FromBytes(out[:]) for i := 0; i < len(fe); i++ {