From 9332385a6f8aad35a9115b563c9547ed96a0aa86 Mon Sep 17 00:00:00 2001 From: Sarah Jamie Lewis Date: Mon, 7 Oct 2019 22:57:45 -0700 Subject: [PATCH] Switching to a different ristretto implementation --- primitives/bulletproofs/inner_product.go | 57 +++--- primitives/bulletproofs/inner_product_test.go | 51 +++-- primitives/bulletproofs/range_proof.go | 179 +++++++++++------- primitives/bulletproofs/range_proof_test.go | 1 - primitives/core/operations.go | 71 +++---- primitives/core/transcript.go | 29 ++- 6 files changed, 222 insertions(+), 166 deletions(-) diff --git a/primitives/bulletproofs/inner_product.go b/primitives/bulletproofs/inner_product.go index cded9c6..c507bf2 100644 --- a/primitives/bulletproofs/inner_product.go +++ b/primitives/bulletproofs/inner_product.go @@ -4,7 +4,7 @@ import ( "cwtch.im/tapir/primitives/core" "encoding/json" "git.openprivacy.ca/openprivacy/libricochet-go/log" - "github.com/bwesterb/go-ristretto" + ristretto "github.com/gtank/ristretto255" "strconv" ) @@ -12,12 +12,12 @@ import ( type InnerProductProof struct { L core.PointVector R core.PointVector - A core.Scalar - B core.Scalar + A *ristretto.Scalar + B *ristretto.Scalar } // ProveInnerProduct generates a proof for , the inner product of a and b -func ProveInnerProduct(a, b core.ScalarVector, u core.Point, P *ristretto.Point, G, H core.GeneratorVector, transcript *core.Transcript) InnerProductProof { +func ProveInnerProduct(a, b core.ScalarVector, u *ristretto.Element, P *ristretto.Element, G, H core.GeneratorVector, transcript *core.Transcript) InnerProductProof { n := len(a) transcript.AddToTranscript("n", []byte(strconv.Itoa(n))) Lvec := core.PointVector{} @@ -26,7 +26,7 @@ func ProveInnerProduct(a, b core.ScalarVector, u core.Point, P *ristretto.Point, transcript.AddToTranscript("G", Gbytes) Hbytes, _ := json.Marshal(H) transcript.AddToTranscript("H", Hbytes) - transcript.AddToTranscript("P'", P.Bytes()) + transcript.AddToTranscript("P'", []byte(P.String())) for n != 1 { np := n / 2 aL, aR := a[:np], a[np:] @@ -40,30 +40,31 @@ func ProveInnerProduct(a, b core.ScalarVector, u core.Point, P *ristretto.Point, L := core.MultiExp(append(aL.Join(bR), cL), append(GR.Join(HL), u)) R := core.MultiExp(append(aR.Join(bL), cR), append(GL.Join(HR), u)) - transcript.AddToTranscript("L", L.Bytes()) + transcript.AddToTranscript("L", []byte(L.String())) Lvec = append(Lvec, L) - transcript.AddToTranscript("R", R.Bytes()) + transcript.AddToTranscript("R",[]byte(R.String())) Rvec = append(Rvec, R) u := transcript.CommitToTranscriptScalar("u") - uinv := new(ristretto.Scalar).Inverse(u) + uinv := new(ristretto.Scalar) + uinv.Invert(u) for i := 0; i < len(aL); i++ { - aL_ := new(ristretto.Scalar).Mul(aL[i], u) - aL[i] = aL_.Add(aL_, new(ristretto.Scalar).Mul(aR[i], uinv)) - bL_ := new(ristretto.Scalar).Mul(bL[i], uinv) - bL[i] = bL_.Add(bL_, new(ristretto.Scalar).Mul(bR[i], u)) + aL_ := new(ristretto.Scalar).Multiply(aL[i], u) + aL[i] = new(ristretto.Scalar).Add(aL_, new(ristretto.Scalar).Multiply(aR[i], uinv)) + bL_ := new(ristretto.Scalar).Multiply(bL[i], uinv) + bL[i] = new(ristretto.Scalar).Add(bL_, new(ristretto.Scalar).Multiply(bR[i], u)) GL[i] = core.MultiExp(core.ScalarVector{uinv, u}, core.GeneratorVector{GL[i], GR[i]}) HL[i] = core.MultiExp(core.ScalarVector{u, uinv}, core.GeneratorVector{HL[i], HR[i]}) } - x2 := new(ristretto.Scalar).Square(u) - P_ := new(ristretto.Point).ScalarMult(L, x2) + x2 := new(ristretto.Scalar).Multiply(u,u) + P_ := new(ristretto.Element).ScalarMult(x2,L) P_.Add(P_, P) - P_.Add(P_, new(ristretto.Point).ScalarMult(R, new(ristretto.Scalar).Inverse(x2))) + P_.Add(P_, new(ristretto.Element).ScalarMult(new(ristretto.Scalar).Invert(x2),R)) P = P_ - transcript.AddToTranscript("P'", P.Bytes()) + transcript.AddToTranscript("P'", []byte(P.String())) a = aL b = bL @@ -80,7 +81,7 @@ func ProveInnerProduct(a, b core.ScalarVector, u core.Point, P *ristretto.Point, } // Verify checks the given inner product proof -func Verify(proof InnerProductProof, n int, u, P *ristretto.Point, G, H core.GeneratorVector, transcript *core.Transcript) bool { +func Verify(proof InnerProductProof, n int, u, P *ristretto.Element, G, H core.GeneratorVector, transcript *core.Transcript) bool { transcript.AddToTranscript("n", []byte(strconv.Itoa(n))) np := n / 2 @@ -88,27 +89,28 @@ func Verify(proof InnerProductProof, n int, u, P *ristretto.Point, G, H core.Gen transcript.AddToTranscript("G", Gbytes) Hbytes, _ := json.Marshal(H) transcript.AddToTranscript("H", Hbytes) - transcript.AddToTranscript("P'", P.Bytes()) + transcript.AddToTranscript("P'", []byte(P.String())) for i := range proof.L { GL, GR := G[:np], G[np:] HL, HR := H[:np], H[np:] - transcript.AddToTranscript("L", proof.L[i].Bytes()) - transcript.AddToTranscript("R", proof.R[i].Bytes()) + transcript.AddToTranscript("L", []byte(proof.L[i].String())) + transcript.AddToTranscript("R", []byte(proof.R[i].String())) x := transcript.CommitToTranscriptScalar("u") - xinv := new(ristretto.Scalar).Inverse(x) + xinv := new(ristretto.Scalar) + xinv.Invert(x) for j := 0; j < np; j++ { GL[j] = core.MultiExp(core.ScalarVector{xinv, x}, core.GeneratorVector{GL[j], GR[j]}) HL[j] = core.MultiExp(core.ScalarVector{x, xinv}, core.GeneratorVector{HL[j], HR[j]}) } - x2 := new(ristretto.Scalar).Square(x) - P_ := new(ristretto.Point).ScalarMult(proof.L[i], x2) + x2 := new(ristretto.Scalar).Multiply(x,x) + P_ := new(ristretto.Element).ScalarMult(x2,proof.L[i]) P_.Add(P_, P) - P_.Add(P_, new(ristretto.Point).ScalarMult(proof.R[i], new(ristretto.Scalar).Inverse(x2))) + P_.Add(P_, new(ristretto.Element).ScalarMult(new(ristretto.Scalar).Invert(x2),proof.R[i])) P = P_ - transcript.AddToTranscript("P'", P.Bytes()) + transcript.AddToTranscript("P'", []byte(P.String())) G = GL H = HL @@ -118,8 +120,9 @@ func Verify(proof InnerProductProof, n int, u, P *ristretto.Point, G, H core.Gen transcript.AddToTranscript("H", Hbytes) np = np / 2 } - c := new(ristretto.Scalar).Mul(proof.A, proof.B) + c := new(ristretto.Scalar) + c.Multiply(proof.A, proof.B) P_ := core.MultiExp(core.ScalarVector{proof.A, proof.B, c}, core.GeneratorVector{G[0], H[0], u}) log.Debugf("P:%v\nP':%v\n", P, P_) - return P.Equals(P_) + return P.Equal(P_) == 1 } diff --git a/primitives/bulletproofs/inner_product_test.go b/primitives/bulletproofs/inner_product_test.go index 6bcf558..d210fb3 100644 --- a/primitives/bulletproofs/inner_product_test.go +++ b/primitives/bulletproofs/inner_product_test.go @@ -3,12 +3,13 @@ package bulletproofs import ( "cwtch.im/tapir/primitives/core" "git.openprivacy.ca/openprivacy/libricochet-go/log" - "github.com/bwesterb/go-ristretto" + ristretto "github.com/gtank/ristretto255" "testing" ) + func assert(t *testing.T, expected *ristretto.Scalar, actual *ristretto.Scalar) { - if expected.Equals(actual) { + if expected.Equal(actual) == 1 { t.Logf("inner_product matched: %v", actual) } else { t.Fatalf("c should be %v instead: %v", expected, actual) @@ -16,8 +17,9 @@ func assert(t *testing.T, expected *ristretto.Scalar, actual *ristretto.Scalar) } func Test_inner_product(t *testing.T) { - one := new(ristretto.Scalar).SetOne() - zero := new(ristretto.Scalar).SetZero() + one := core.IdentityVector(1)[0] + zero := new(ristretto.Scalar) + zero.Zero() a := core.ScalarVector{one, zero, one, zero} b := core.ScalarVector{zero, one, zero, one} c := core.InnerProduct(a, b) @@ -26,35 +28,42 @@ func Test_inner_product(t *testing.T) { a = core.ScalarVector{one, one, one, zero} b = core.ScalarVector{one, one, zero, one} c = core.InnerProduct(a, b) - assert(t, new(ristretto.Scalar).Add(one, one), c) + + check := new(ristretto.Scalar) + check.Add(one,one) + + assert(t, check, c) } func TestProveInnerProduct(t *testing.T) { log.SetLevel(log.LevelDebug) - one := new(ristretto.Scalar).SetOne() - zero := new(ristretto.Scalar).SetZero() + one := core.IdentityVector(1)[0] + zero := new(ristretto.Scalar) + zero.Zero() a := core.ScalarVector{one, zero, one, one} b := core.ScalarVector{zero, one, one, one} - c := core.InnerProduct(a, b) - G := make(core.GeneratorVector, 4) - H := make(core.GeneratorVector, 4) - for i := 0; i < 4; i++ { - G[i] = new(ristretto.Point).Rand() - H[i] = new(ristretto.Point).Rand() - } - - u := new(ristretto.Point).Rand() - - P_ := core.MultiExp(append(a.Join(b), c), append(G.Join(H), u)) - proverTranscript := core.NewTranscript("test_innerproductproof") verifierTranscript := core.NewTranscript("test_innerproductproof") - proof := ProveInnerProduct(a, b, u, new(ristretto.Point).Set(P_), core.CopyVector(G), core.CopyVector(H), proverTranscript) - if Verify(proof, 4, u, new(ristretto.Point).Set(P_), core.CopyVector(G), core.CopyVector(H), verifierTranscript) { + G := proverTranscript.CommitToGenerators("G",4) + H := proverTranscript.CommitToGenerators("H",4) + u := proverTranscript.CommitToGenerator("u") + + verifierTranscript.CommitToGenerators("G",4) + verifierTranscript.CommitToGenerators("H",4) + verifierTranscript.CommitToGenerator("u") + + c := core.InnerProduct(a,b) + + P_ := core.MultiExp(append(a.Join(b), c), append(core.GeneratorVector(G).Join(core.GeneratorVector(H)), u)) + + + proof := ProveInnerProduct(a, b, u, new(ristretto.Element).Add(new(ristretto.Element).Zero(), P_), core.CopyVector(G), core.CopyVector(H), proverTranscript) + + if Verify(proof, 4, u, P_, core.CopyVector(G), core.CopyVector(H), verifierTranscript) { t.Logf("Inner Product Proof Passed!") } else { t.Logf("%v\n\n%v\n", proverTranscript.OutputTranscriptToAudit(), verifierTranscript.OutputTranscriptToAudit()) diff --git a/primitives/bulletproofs/range_proof.go b/primitives/bulletproofs/range_proof.go index 6bfd015..7f4e89e 100644 --- a/primitives/bulletproofs/range_proof.go +++ b/primitives/bulletproofs/range_proof.go @@ -1,32 +1,33 @@ package bulletproofs import ( + "crypto/rand" "cwtch.im/tapir/primitives/core" "git.openprivacy.ca/openprivacy/libricochet-go/log" - "github.com/bwesterb/go-ristretto" + ristretto "github.com/gtank/ristretto255" "math/big" ) // RangeProof encapsulates a proof that V = [0, max) where max is defined by the Setup function type RangeProof struct { - A *ristretto.Point - S *ristretto.Point - T1 *ristretto.Point - T2 *ristretto.Point + A *ristretto.Element + S *ristretto.Element + T1 *ristretto.Element + T2 *ristretto.Element TauX *ristretto.Scalar InnerProduct *ristretto.Scalar Mu *ristretto.Scalar IPP InnerProductProof - V *ristretto.Point + V *ristretto.Element } // CommitmentsParams encapsulates the commitment parameters for a given range proof type CommitmentsParams struct { G core.GeneratorVector H core.GeneratorVector - u *ristretto.Point - g *ristretto.Point - h *ristretto.Point + u *ristretto.Element + g *ristretto.Element + h *ristretto.Element max int } @@ -43,69 +44,92 @@ func Setup(max int, transcript *core.Transcript) (c CommitmentsParams) { return } +func Rand(seed string) *ristretto.Scalar { + t := core.NewTranscript(seed) + return t.CommitToTranscriptScalar("seed") +} + // GenerateRangeProof creates a valid rangeproof that value is within [0,max) under the given transcript -// It returns the rangeproof and a random scalar "gamma" that can be used to open V, the commitement to v vGgH +// It returns the rangeproof and a random scalar "gamma" that can be used to open V, the commitment to v vGgH func GenerateRangeProof(value int32, c CommitmentsParams, transcript *core.Transcript) (RangeProof, *ristretto.Scalar) { - one := new(ristretto.Scalar).SetOne() - two := new(ristretto.Scalar).Add(one, one) + one := core.IdentityVector(1)[0] + two := new(ristretto.Scalar) + two.Add(one, one) two_n := core.PowerVector(two, c.max) - gamma := new(ristretto.Scalar).Rand() + // Generate a prng to from this transcript and some external randomness + // We use this to generate the rest of our private scalars + // TODO: move to transcript + private := make([]byte,64) + rand.Read(private) + prngTranscript:= core.NewTranscript("private-transcript") + prngTranscript.AddToTranscript("randomness",private) + prng := prngTranscript.CommitToPRNG(transcript.OutputTranscriptToAudit()) + + gamma := prng.Next() aL := valueToVector(value, c.max) - aR := core.VectorAddScalar(aL, new(ristretto.Scalar).Neg(one)) - alpha := new(ristretto.Scalar).Rand() + aR := core.VectorAddScalar(aL, new(ristretto.Scalar).Negate(one)) + alpha := prng.Next() + + vs := new(ristretto.Scalar) + b := make([]byte,32) + copy(b,big.NewInt(int64(value)).Bytes()) + vs.Decode(b) + - vs := new(ristretto.Scalar).SetBigInt(big.NewInt(int64(value))) V := core.MultiExp(core.ScalarVector{gamma, vs}, core.GeneratorVector{c.h, c.g}) A := core.MultiExp(append(aL.Join(aR), alpha), append(c.G.Join(c.H), c.h)) + log.Debugf("vs: %v", vs) Sl := make(core.ScalarVector, c.max) Sr := make(core.ScalarVector, c.max) for i := 0; i < c.max; i++ { - Sl[i] = new(ristretto.Scalar).Rand() - Sr[i] = new(ristretto.Scalar).Rand() + Sl[i] = prng.Next() + Sr[i] = prng.Next() } - p := new(ristretto.Scalar).Rand() + p := prng.Next() S := core.MultiExp(append(Sl.Join(Sr), p), append(c.G.Join(c.H), c.h)) - transcript.AddToTranscript("A", A.Bytes()) - transcript.AddToTranscript("S", S.Bytes()) + transcript.AddToTranscript("A", []byte(A.String())) + transcript.AddToTranscript("S",[]byte(S.String())) y := transcript.CommitToTranscriptScalar("y") z := transcript.CommitToTranscriptScalar("z") y_n := core.PowerVector(y, c.max) - z2 := new(ristretto.Scalar).Square(z) + z2 := new(ristretto.Scalar).Multiply(z,z) - l0 := core.VectorAddScalar(aL, new(ristretto.Scalar).Neg(z)) + l0 := core.VectorAddScalar(aL, new(ristretto.Scalar).Negate(z)) //l1 == Sr r0 := core.EntrywiseSum(core.EntryWiseProduct(y_n, core.VectorAddScalar(aR, z)), core.VectorMulScalar(two_n, z2)) r1 := core.EntryWiseProduct(Sr, y_n) - t0 := new(ristretto.Scalar).Add(new(ristretto.Scalar).Mul(z2, vs), delta(y_n, z, c.max)) - t1 := new(ristretto.Scalar).Add(core.InnerProduct(Sl, r0), core.InnerProduct(l0, r1)) + t0 := new(ristretto.Scalar).Add(new(ristretto.Scalar).Multiply(z2, vs), delta(y_n, z, c.max)) + t1 := new(ristretto.Scalar) + t1.Add(core.InnerProduct(Sl, r0), core.InnerProduct(l0, r1)) t2 := core.InnerProduct(Sl, r1) - tau1 := new(ristretto.Scalar).Rand() - tau2 := new(ristretto.Scalar).Rand() + tau1 := prng.Next() + tau2 := prng.Next() T1 := core.MultiExp(core.ScalarVector{t1, tau1}, core.GeneratorVector{c.g, c.h}) T2 := core.MultiExp(core.ScalarVector{t2, tau2}, core.GeneratorVector{c.g, c.h}) - transcript.AddToTranscript("T1", T1.Bytes()) - transcript.AddToTranscript("T2", T2.Bytes()) + transcript.AddToTranscript("T1", []byte(T1.String())) + transcript.AddToTranscript("T2", []byte(T2.String())) x := transcript.CommitToTranscriptScalar("x") // T(X) = t0 + t1x + t2x - TX := new(ristretto.Scalar).Set(t0) - TX.Add(TX, new(ristretto.Scalar).Mul(t1, x)) - TX.Add(TX, new(ristretto.Scalar).Mul(t2, new(ristretto.Scalar).Square(x))) + TX := new(ristretto.Scalar) + TX.Add(new(ristretto.Scalar).Zero(), t0) + TX.Add(TX, new(ristretto.Scalar).Multiply(t1, x)) + TX.Add(TX, new(ristretto.Scalar).Multiply(t2, new(ristretto.Scalar).Multiply(x,x))) - l := core.EntrywiseSum(core.VectorAddScalar(aL, new(ristretto.Scalar).Neg(z)), core.VectorMulScalar(Sl, x)) + l := core.EntrywiseSum(core.VectorAddScalar(aL, new(ristretto.Scalar).Negate(z)), core.VectorMulScalar(Sl, x)) _r := core.EntrywiseSum(core.VectorAddScalar(aR, z), core.VectorMulScalar(Sr, x)) r := core.EntrywiseSum(core.EntryWiseProduct(y_n, _r), core.VectorMulScalar(two_n, z2)) @@ -113,103 +137,118 @@ func GenerateRangeProof(value int32, c CommitmentsParams, transcript *core.Trans log.Debugf("T(X) = %v", TX) log.Debugf("ipp = %v", iplr) - log.Debugf("equal: %v", TX.Equals(iplr)) + log.Debugf("equal: %v", TX.Equal(iplr) == 1) // generate h' H_ := make(core.GeneratorVector, c.max) H_[0] = c.H[0] for i := 1; i < c.max; i++ { - H_[i] = new(ristretto.Point).ScalarMult(c.H[i], new(ristretto.Scalar).Inverse(y_n[i])) + H_[i] = new(ristretto.Element).ScalarMult(new(ristretto.Scalar).Invert(y_n[i]),c.H[i],) } P := core.MultiExp(l.Join(r), c.G.Join(H_)) log.Debugf("P: %v", P) - uP := new(ristretto.Point).Add(P, new(ristretto.Point).ScalarMult(c.u, iplr)) + uP := new(ristretto.Element).Add(P, new(ristretto.Element).ScalarMult(iplr,c.u)) log.Debugf("uP: %v", uP) - ipp := ProveInnerProduct(l, r, c.u, new(ristretto.Point).Set(uP), core.CopyVector(c.G), core.CopyVector(H_), transcript) + ipp := ProveInnerProduct(l, r, c.u, new(ristretto.Element).Add(new(ristretto.Element).Zero(),uP), core.CopyVector(c.G), core.CopyVector(H_), transcript) - taux := new(ristretto.Scalar).Mul(tau2, new(ristretto.Scalar).Square(x)) - taux = taux.Add(taux, new(ristretto.Scalar).Mul(tau1, x)) - taux = taux.Add(taux, new(ristretto.Scalar).Mul(z2, gamma)) + taux := new(ristretto.Scalar) + taux.Multiply(tau2, new(ristretto.Scalar).Multiply(x,x)) + taux.Add(taux, new(ristretto.Scalar).Multiply(tau1, x)) + taux.Add(taux, new(ristretto.Scalar).Multiply(z2, gamma)) - mu := new(ristretto.Scalar).Add(alpha, new(ristretto.Scalar).Mul(p, x)) + mu := new(ristretto.Scalar) + mu.Add(alpha, new(ristretto.Scalar).Multiply(p, x)) return RangeProof{A, S, T1, T2, taux, iplr, mu, ipp, V}, gamma } // VerifyRangeProof returns true if the given proof passes all the checks for a given set of commitment parameters // and the given transcript func VerifyRangeProof(proof RangeProof, c CommitmentsParams, transcript *core.Transcript) bool { - one := new(ristretto.Scalar).SetOne() - two := new(ristretto.Scalar).Add(one, one) + one := core.IdentityVector(1)[0] + two := new(ristretto.Scalar) + two.Add(one, one) two_n := core.PowerVector(two, c.max) - transcript.AddToTranscript("A", proof.A.Bytes()) - transcript.AddToTranscript("S", proof.S.Bytes()) + transcript.AddToTranscript("A", []byte(proof.A.String())) + transcript.AddToTranscript("S",[]byte(proof.S.String())) y := transcript.CommitToTranscriptScalar("y") z := transcript.CommitToTranscriptScalar("z") - transcript.AddToTranscript("T1", proof.T1.Bytes()) - transcript.AddToTranscript("T2", proof.T2.Bytes()) + transcript.AddToTranscript("T1", []byte(proof.T1.String())) + transcript.AddToTranscript("T2", []byte(proof.T2.String())) x := transcript.CommitToTranscriptScalar("x") y_n := core.PowerVector(y, c.max) // generate h' H_ := make(core.GeneratorVector, c.max) H_[0] = c.H[0] for i := 1; i < c.max; i++ { - H_[i] = new(ristretto.Point).ScalarMult(c.H[i], new(ristretto.Scalar).Inverse(y_n[i])) + H_[i] = new(ristretto.Element).ScalarMult(new(ristretto.Scalar).Invert(y_n[i]),c.H[i],) } // check t = t(x) = t0 + t1.x + t1.x^2 lhs := core.MultiExp(core.ScalarVector{proof.InnerProduct, proof.TauX}, core.GeneratorVector{c.g, c.h}) - rhs := core.MultiExp(core.ScalarVector{new(ristretto.Scalar).Square(z), delta(y_n, z, c.max), x, new(ristretto.Scalar).Square(x)}, core.GeneratorVector{proof.V, c.g, proof.T1, proof.T2}) + + + z2 := new(ristretto.Scalar) + z2.Multiply(z,z) + + x2 := new(ristretto.Scalar) + x2.Multiply(x,x) + + rhs := core.MultiExp(core.ScalarVector{z2, delta(y_n, z, c.max), x, x2}, core.GeneratorVector{proof.V, c.g, proof.T1, proof.T2}) log.Debugf("lhs: %v", lhs) log.Debugf("rhs: %v", rhs) - log.Debugf("equal: %v", lhs.Equals(rhs)) + log.Debugf("equal: %v", lhs.Equal(rhs)) - if lhs.Equals(rhs) { + if lhs.Equal(rhs) == 1{ // compute P - negz := new(ristretto.Scalar).Neg(z) - negzG := new(ristretto.Point).SetZero() + negz := new(ristretto.Scalar).Negate(z) + negzG := new(ristretto.Element).Zero() for _, gen := range c.G { - negzG = negzG.Add(negzG, new(ristretto.Point).ScalarMult(gen, negz)) + negzG = negzG.Add(negzG, new(ristretto.Element).ScalarMult(negz, gen)) } - mul := core.EntrywiseSum(core.VectorMulScalar(y_n, z), core.VectorMulScalar(two_n, new(ristretto.Scalar).Square(z))) + mul := core.EntrywiseSum(core.VectorMulScalar(y_n, z), core.VectorMulScalar(two_n, new(ristretto.Scalar).Multiply(z,z))) - Pr := new(ristretto.Point).Set(proof.A) - Pr = Pr.Add(Pr, new(ristretto.Point).ScalarMult(proof.S, x)) + Pr := new(ristretto.Element).Add(new(ristretto.Element).Zero(),proof.A) + Pr = Pr.Add(Pr, new(ristretto.Element).ScalarMult(x,proof.S)) Pr = Pr.Add(Pr, negzG) Pr = Pr.Add(Pr, core.MultiExp(mul, H_)) - Pl := new(ristretto.Point).Sub(Pr, new(ristretto.Point).ScalarMult(c.h, proof.Mu)) + Pl := new(ristretto.Element).Subtract(Pr, new(ristretto.Element).ScalarMult(proof.Mu,c.h)) // check inner product - uP := new(ristretto.Point).Add(Pl, new(ristretto.Point).ScalarMult(c.u, proof.InnerProduct)) + uP := new(ristretto.Element).Add(Pl, new(ristretto.Element).ScalarMult(proof.InnerProduct, c.u)) - return Verify(proof.IPP, c.max, c.u, new(ristretto.Point).Set(uP), core.CopyVector(c.G), core.CopyVector(H_), transcript) + return Verify(proof.IPP, c.max, c.u, new(ristretto.Element).Add(new(ristretto.Element).Zero(),uP), core.CopyVector(c.G), core.CopyVector(H_), transcript) } return false } -func delta(y_n core.ScalarVector, z core.Scalar, max int) core.Scalar { - one := new(ristretto.Scalar).SetOne() +func delta(y_n core.ScalarVector, z *ristretto.Scalar, max int) *ristretto.Scalar { + one := core.IdentityVector(1)[0] // (z-z^2) - z2 := new(ristretto.Scalar).Square(z) - result := new(ristretto.Scalar).Sub(z, z2) + z2 := new(ristretto.Scalar).Multiply(z,z) + result := new(ristretto.Scalar) + result.Subtract(z, z2) // (z-z^2) * <1^n,y^n> - result.Mul(result, core.InnerProduct(core.IdentityVector(max), y_n)) - two := new(ristretto.Scalar).Add(one, one) + result.Multiply(result, core.InnerProduct(core.IdentityVector(max), y_n)) + two := new(ristretto.Scalar) + two.Add(one, one) two_n := core.PowerVector(two, max) // (z-z^2) * <1^n,y^n> - z^3 *<1n,2n> - z3 := new(ristretto.Scalar).Mul(z2, z) - return result.Sub(result, new(ristretto.Scalar).Mul(z3, core.InnerProduct(core.IdentityVector(max), two_n))) + z3 := new(ristretto.Scalar).Multiply(z2, z) + result.Subtract(result, new(ristretto.Scalar).Multiply(z3, core.InnerProduct(core.IdentityVector(max), two_n))) + return result } func valueToVector(value int32, max int) core.ScalarVector { - one := new(ristretto.Scalar).SetOne() - zero := new(ristretto.Scalar).SetZero() + one := core.IdentityVector(1)[0] + zero := new(ristretto.Scalar) + zero.Zero() result := core.ScalarVector{} for len(result) != max { v := value & 0x0001 diff --git a/primitives/bulletproofs/range_proof_test.go b/primitives/bulletproofs/range_proof_test.go index d89bdf1..cef9f9d 100644 --- a/primitives/bulletproofs/range_proof_test.go +++ b/primitives/bulletproofs/range_proof_test.go @@ -23,5 +23,4 @@ func TestProove(t *testing.T) { t.Logf("%v\n\n%v\n", proverTranscript.OutputTranscriptToAudit(), verifierTranscript.OutputTranscriptToAudit()) t.Fatalf("Failed to Verify Range Proof") } - } diff --git a/primitives/core/operations.go b/primitives/core/operations.go index 79dfb63..2f69dee 100644 --- a/primitives/core/operations.go +++ b/primitives/core/operations.go @@ -2,29 +2,27 @@ package core import ( "fmt" - "github.com/bwesterb/go-ristretto" + ristretto "github.com/gtank/ristretto255" ) -// Scalar is short hand for a ristretto scalar -type Scalar *ristretto.Scalar - -// Point is short hand of a ristretto point -type Point *ristretto.Point - // ScalarVector explicit type checking -type ScalarVector []Scalar +type ScalarVector []*ristretto.Scalar -// PointVector explicit type checking -type PointVector []*ristretto.Point +// ElementVector explicit type checking +type PointVector []*ristretto.Element // GeneratorVector explicit type checking -type GeneratorVector []*ristretto.Point +type GeneratorVector []*ristretto.Element + + + + // CopyVector safely copies a vector func CopyVector(G GeneratorVector) GeneratorVector { H := make(GeneratorVector, len(G)) for i, g := range G { - H[i] = new(ristretto.Point).Set(g) + H[i] = new(ristretto.Element).Add(new(ristretto.Element).Zero(), g) } return H } @@ -35,23 +33,23 @@ func InnerProduct(a, b ScalarVector) *ristretto.Scalar { panic(fmt.Sprintf("len(a) = %v ; len(b) = %v;", len(a), len(b))) } - result := new(ristretto.Scalar).SetZero() + result := new(ristretto.Scalar).Zero() for i, ai := range a { - result.Add(result, new(ristretto.Scalar).Mul(ai, b[i])) + result.Add(result, new(ristretto.Scalar).Multiply(ai, b[i])) } return result } // MultiExp takes in a vector of scalars = {a,b,c...} and a vector of generator = {A,B,C...} and outputs // {aA,bB,cC} -func MultiExp(a ScalarVector, G GeneratorVector) *ristretto.Point { +func MultiExp(a ScalarVector, G GeneratorVector) *ristretto.Element { if len(a) != len(G) { panic(fmt.Sprintf("len(a) = %v ; len(b) = %v;", len(a), len(G))) } - result := new(ristretto.Point).SetZero() + result := new(ristretto.Element).Zero() for i, ai := range a { - aG := new(ristretto.Point).ScalarMult(G[i], ai) - result = new(ristretto.Point).Add(result, aG) + aG := new(ristretto.Element).ScalarMult(ai, G[i]) + result = new(ristretto.Element).Add(result, aG) } return result } @@ -60,10 +58,10 @@ func MultiExp(a ScalarVector, G GeneratorVector) *ristretto.Point { func (a ScalarVector) Join(b ScalarVector) ScalarVector { list := make(ScalarVector, len(a)+len(b)) for i := 0; i < len(a); i++ { - list[i] = new(ristretto.Scalar).Set(a[i]) + list[i] = a[i] } for i := len(a); i < len(b)+len(b); i++ { - list[i] = new(ristretto.Scalar).Set(b[i-len(a)]) + list[i] = b[i-len(a)] } return list } @@ -72,10 +70,10 @@ func (a ScalarVector) Join(b ScalarVector) ScalarVector { func (a GeneratorVector) Join(b GeneratorVector) GeneratorVector { list := make(GeneratorVector, len(a)+len(b)) for i := 0; i < len(a); i++ { - list[i] = new(ristretto.Point).Set(a[i]) + list[i] = a[i] } for i := len(a); i < len(b)+len(b); i++ { - list[i] = new(ristretto.Point).Set(b[i-len(a)]) + list[i] = b[i-len(a)] } return list } @@ -84,7 +82,8 @@ func (a GeneratorVector) Join(b GeneratorVector) GeneratorVector { func VectorAddScalar(vector ScalarVector, scalar *ristretto.Scalar) ScalarVector { result := make(ScalarVector, len(vector)) for i := range vector { - result[i] = new(ristretto.Scalar).Add(vector[i], scalar) + result[i] = new(ristretto.Scalar) + result[i].Add(vector[i], scalar) } return result } @@ -93,7 +92,8 @@ func VectorAddScalar(vector ScalarVector, scalar *ristretto.Scalar) ScalarVector func VectorMulScalar(vector ScalarVector, scalar *ristretto.Scalar) ScalarVector { result := make(ScalarVector, len(vector)) for i := range vector { - result[i] = new(ristretto.Scalar).Mul(vector[i], scalar) + result[i] = new(ristretto.Scalar) + result[i].Multiply(vector[i], scalar) } return result } @@ -102,7 +102,8 @@ func VectorMulScalar(vector ScalarVector, scalar *ristretto.Scalar) ScalarVector func EntrywiseSum(vector ScalarVector, vector2 ScalarVector) ScalarVector { result := make(ScalarVector, len(vector)) for i, v := range vector { - result[i] = new(ristretto.Scalar).Add(v, vector2[i]) + result[i] = new(ristretto.Scalar) + result[i].Add(v, vector2[i]) } return result } @@ -111,7 +112,8 @@ func EntrywiseSum(vector ScalarVector, vector2 ScalarVector) ScalarVector { func EntryWiseProduct(vector ScalarVector, vector2 ScalarVector) ScalarVector { result := make(ScalarVector, len(vector)) for i, v := range vector { - result[i] = new(ristretto.Scalar).Mul(v, vector2[i]) + result[i] = new(ristretto.Scalar) + result[i].Multiply(v, vector2[i]) } return result } @@ -119,8 +121,10 @@ func EntryWiseProduct(vector ScalarVector, vector2 ScalarVector) ScalarVector { // IdentityVector is a convenience function to generate a vector v = {1,1,1...1} func IdentityVector(n int) ScalarVector { result := make(ScalarVector, n) + one := new(ristretto.Scalar) + one.Decode([]byte{1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}) for i := 0; i < n; i++ { - result[i] = new(ristretto.Scalar).SetOne() + result[i] = one } return result } @@ -128,10 +132,13 @@ func IdentityVector(n int) ScalarVector { // PowerVector creates a vector v = {1,x,x^2,x^3..x^n} func PowerVector(x *ristretto.Scalar, n int) ScalarVector { result := make(ScalarVector, n) - result[0] = new(ristretto.Scalar).SetOne() - result[1] = new(ristretto.Scalar).Set(x) - for i := 1; i < n; i++ { - result[i] = new(ristretto.Scalar).Mul(result[i-1], result[1]) + one := new(ristretto.Scalar) + one.Decode([]byte{1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}) + result[0] = one + result[1] = x + for i := 2; i < n; i++ { + result[i] = new(ristretto.Scalar) + result[i].Multiply(result[i-1], x) } return result -} +} \ No newline at end of file diff --git a/primitives/core/transcript.go b/primitives/core/transcript.go index 96471f8..b84e762 100644 --- a/primitives/core/transcript.go +++ b/primitives/core/transcript.go @@ -2,18 +2,17 @@ package core import ( "fmt" - "github.com/bwesterb/go-ristretto" + ristretto "github.com/gtank/ristretto255" "golang.org/x/crypto/sha3" "hash" "io" ) -// Transcript implements a transcript of a public coin argument. +// Transcript provides a consistent transcript primitive for our protocols // // We have the following goals: -// - Provide a consisted transcript API for our zero knowledge protocols // - Allow sequential proofs over a common transcript (ensuring a single proof cannot be extracted standalone) -// - produce an auditable human-readable transcript. +// - be able to produce a human-readable transcript for auditing. // // The design of this API was inspired by Merlin: https://docs.rs/crate/merlin/ // @@ -26,7 +25,7 @@ type Transcript struct { // NewTranscript creates a new Transcript with the given Label, the label should be unique to the application func NewTranscript(label string) *Transcript { transcript := new(Transcript) - transcript.hash = sha3.New256() + transcript.hash = sha3.New512() transcript.AddToTranscript("protocol", []byte(label)) return transcript } @@ -67,9 +66,11 @@ type PRNG struct { // Next returns the next "random" scalar from the PRNG func (prng *PRNG) Next() *ristretto.Scalar { - buf := [32]byte{} + buf := [64]byte{} io.ReadFull(prng.prng, buf[:]) - return new(ristretto.Scalar).SetBytes(&buf) + next := new(ristretto.Scalar) + next.FromUniformBytes(buf[:]) + return next } // CommitToPRNG commits the label to the transcript and derives a PRNG from the transcript. @@ -83,15 +84,13 @@ func (t *Transcript) CommitToPRNG(label string) PRNG { } // CommitToGenerator derives a verifiably random generator from the transcript -func (t *Transcript) CommitToGenerator(label string) *ristretto.Point { +func (t *Transcript) CommitToGenerator(label string) *ristretto.Element { c := t.CommitToTranscript(label) - cs := [32]byte{} - copy(cs[:], c[:]) - return new(ristretto.Point).SetElligator(&cs) + return new(ristretto.Element).FromUniformBytes(c) } // CommitToGenerators derives a set of verifiably random generators from the transcript -func (t *Transcript) CommitToGenerators(label string, n int) (generators []*ristretto.Point) { +func (t *Transcript) CommitToGenerators(label string, n int) (generators []*ristretto.Element) { for i := 0; i < n; i++ { generators = append(generators, t.CommitToGenerator(fmt.Sprintf("%v-%d", label, i))) } @@ -101,7 +100,7 @@ func (t *Transcript) CommitToGenerators(label string, n int) (generators []*rist // CommitToTranscriptScalar is a convenience method for CommitToTranscript which returns a ristretto Scalar func (t *Transcript) CommitToTranscriptScalar(label string) *ristretto.Scalar { c := t.CommitToTranscript(label) - cs := [32]byte{} - copy(cs[:], c[:]) - return new(ristretto.Scalar).SetBytes(&cs) + s := new(ristretto.Scalar) + s.FromUniformBytes(c[:]) + return s }