diff --git a/internal/radix51/fe.go b/internal/radix51/fe.go index 65ac1b5..f9dd35b 100644 --- a/internal/radix51/fe.go +++ b/internal/radix51/fe.go @@ -323,6 +323,26 @@ func (v *FieldElement) Select(a, b *FieldElement, cond int) *FieldElement { return v } +// CondSwap swaps a and b if cond == 1 or leaves them unchanged if cond == 0. +func CondSwap(a, b *FieldElement, cond int) { + m := uint64(cond) * mask64Bits + t := m & (a[0] ^ b[0]) + a[0] ^= t + b[0] ^= t + t = m & (a[1] ^ b[1]) + a[1] ^= t + b[1] ^= t + t = m & (a[2] ^ b[2]) + a[2] ^= t + b[2] ^= t + t = m & (a[3] ^ b[3]) + a[3] ^= t + b[3] ^= t + t = m & (a[4] ^ b[4]) + a[4] ^= t + b[4] ^= t +} + // CondNeg sets v to -u if cond == 1, and to u if cond == 0. func (v *FieldElement) CondNeg(u *FieldElement, cond int) *FieldElement { tmp := new(FieldElement).Neg(u) diff --git a/internal/radix51/fe_test.go b/internal/radix51/fe_test.go index 332b82d..1b0ff69 100644 --- a/internal/radix51/fe_test.go +++ b/internal/radix51/fe_test.go @@ -335,3 +335,29 @@ func TestInvert(t *testing.T) { t.Errorf("random inversion identity failed, got: %x for field element %x", r, x) } } + +func TestSelectSwap(t *testing.T) { + a := FieldElement([5]uint64{358744748052810, 1691584618240980, 977650209285361, 1429865912637724, 560044844278676}) + b := FieldElement([5]uint64{84926274344903, 473620666599931, 365590438845504, 1028470286882429, 2146499180330972}) + + var c, d FieldElement + + c.Select(&a, &b, 1) + d.Select(&a, &b, 0) + + if (c.Equal(&a) != 1) || (d.Equal(&b) != 1) { + t.Errorf("Select failed") + } + + CondSwap(&c, &d, 0) + + if (c.Equal(&a) != 1) || (d.Equal(&b) != 1) { + t.Errorf("Swap failed") + } + + CondSwap(&c, &d, 1) + + if (c.Equal(&b) != 1) || (d.Equal(&a) != 1) { + t.Errorf("Swap failed") + } +}