diff --git a/internal/radix51/alias_test.go b/internal/radix51/alias_test.go index 54089b9..7e2f84d 100644 --- a/internal/radix51/alias_test.go +++ b/internal/radix51/alias_test.go @@ -14,7 +14,7 @@ func checkAliasingOneArg(f func(v, x *FieldElement) *FieldElement) func(v, x Fie x1, v1 := x, x // Calculate a reference f(x) without aliasing. - if out := f(&v, &x); out != &v { + if out := f(&v, &x); out != &v && isInBounds(out) { return false } @@ -33,7 +33,7 @@ func checkAliasingTwoArgs(f func(v, x, y *FieldElement) *FieldElement) func(v, x x1, y1, v1 := x, y, FieldElement{} // Calculate a reference f(x, y) without aliasing. - if out := f(&v, &x, &y); out != &v { + if out := f(&v, &x, &y); out != &v && isInBounds(out) { return false } diff --git a/internal/radix51/fe_test.go b/internal/radix51/fe_test.go index 7c8303f..c9c13e6 100644 --- a/internal/radix51/fe_test.go +++ b/internal/radix51/fe_test.go @@ -10,6 +10,7 @@ import ( "crypto/rand" "io" "math/big" + "math/bits" mathrand "math/rand" "reflect" "testing" @@ -38,6 +39,16 @@ func (x FieldElement) Generate(rand *mathrand.Rand, size int) reflect.Value { return reflect.ValueOf(generateFieldElement(rand)) } +// isInBounds returns whether the element is within the expected bit size bounds +// after a light reduction. +func isInBounds(x *FieldElement) bool { + return bits.Len64(x[0]) <= 52 && + bits.Len64(x[1]) <= 51 && + bits.Len64(x[2]) <= 51 && + bits.Len64(x[3]) <= 51 && + bits.Len64(x[4]) <= 51 +} + func TestMulDistributesOverAdd(t *testing.T) { mulDistributesOverAdd := func(x, y, z FieldElement) bool { // Compute t1 = (x+y)*z @@ -52,7 +63,7 @@ func TestMulDistributesOverAdd(t *testing.T) { t3.Mul(&y, &z) t2.Add(t2, t3) - return t1.Equal(t2) == 1 + return t1.Equal(t2) == 1 && isInBounds(t1) && isInBounds(t2) } if err := quick.Check(mulDistributesOverAdd, quickCheckConfig); err != nil { @@ -112,7 +123,7 @@ func TestFromBytesRoundTrip(t *testing.T) { // not good. We should have a weird generator that aims for edge cases, // and we'll know it works when this test breaks. - return bytes.Equal(in[:], out[:]) + return bytes.Equal(in[:], out[:]) && isInBounds(&fe) } if err := quick.Check(f1, nil); err != nil { t.Errorf("failed bytes->FE->bytes round-trip: %v", err) @@ -179,7 +190,7 @@ func TestBytesBigEquivalence(t *testing.T) { buf := make([]byte, 32) // pad with zeroes copy(buf, swapEndianness(fe1.ToBig().Bytes())) - return bytes.Equal(out[:], buf) + return bytes.Equal(out[:], buf) && isInBounds(&fe) && isInBounds(&fe1) } if err := quick.Check(f1, nil); err != nil { t.Error(err)