diff --git a/primitives/bulletproofs/constaint_system_test.go b/primitives/bulletproofs/constaint_system_test.go new file mode 100644 index 0000000..847a5c8 --- /dev/null +++ b/primitives/bulletproofs/constaint_system_test.go @@ -0,0 +1,112 @@ +package bulletproofs + +import ( + "crypto/sha512" + "cwtch.im/tapir/primitives/core" + "git.openprivacy.ca/openprivacy/libricochet-go/log" + ristretto "github.com/gtank/ristretto255" + "testing" +) + +func TestConstraintSystem(t *testing.T) { + log.SetLevel(log.LevelDebug) + cs := NewConstrainSystem(Setup(2, core.NewTranscript(""))) + + abytes := sha512.Sum512([]byte("a")) + a := new(ristretto.Scalar) + a.FromUniformBytes(abytes[:]) + + bbytes := sha512.Sum512([]byte("b")) + b := new(ristretto.Scalar) + b.FromUniformBytes(bbytes[:]) + + xbytes := sha512.Sum512([]byte("a")) + x := new(ristretto.Scalar) + x.FromUniformBytes(xbytes[:]) + + ybytes := sha512.Sum512([]byte("b")) + y := new(ristretto.Scalar) + y.FromUniformBytes(ybytes[:]) + + prng := core.NewTranscript("private").CommitToPRNG("private") + + V1, a_lc := cs.Commit(a, prng.Next()) + V2, b_lc := cs.Commit(b, prng.Next()) + V3, x_lc := cs.Commit(x, prng.Next()) + V4, y_lc := cs.Commit(y, prng.Next()) + + cs.VerifierCommit(V1) + cs.VerifierCommit(V2) + cs.VerifierCommit(V3) + cs.VerifierCommit(V4) + + _, _, in := cs.Multiply(a_lc, b_lc) + _, _, out := cs.Multiply(x_lc, y_lc) + cs.Constrain(in.Sub(out)) + + wL, wR, wO, wV := cs.flatten(core.One()) + + lhs := new(ristretto.Scalar) + lhs.Add(lhs, core.InnerProduct(cs.aL, wL)) + lhs.Add(lhs, core.InnerProduct(cs.aR, wR)) + + rhs := new(ristretto.Scalar) + rhs.Add(rhs, core.InnerProduct(cs.aO, wO)) + rrhs := new(ristretto.Scalar) + rrhs.Add(rrhs, core.InnerProduct(cs.v, wV)) + + t.Logf("lhs: %v\n", lhs) + t.Logf("rhs: %v\n", rhs) + t.Logf("rrhs: %v\n", rrhs) + t.Logf("\nwL: %v\nwR: %v\nwO: %v\nwC: %v", wL, wR, wO, wV) + + proof := cs.Prove(cs.params, core.NewTranscript("")) + + t.Logf("Proof Result: %v", cs.Verify(proof, cs.params, core.NewTranscript(""))) +} + +func TestConstraintSystemMix(t *testing.T) { + log.SetLevel(log.LevelDebug) + cs := NewConstrainSystem(Setup(2, core.NewTranscript(""))) + + one := core.One() + two := new(ristretto.Scalar).Add(one, one) + three := new(ristretto.Scalar).Add(two, one) + four := new(ristretto.Scalar).Add(two, two) + + prng := core.NewTranscript("private").CommitToPRNG("private") + + V1, a_lc := cs.Commit(three, prng.Next()) + V2, b_lc := cs.Commit(three, prng.Next()) + V3, x_lc := cs.Commit(four, prng.Next()) + V4, y_lc := cs.Commit(two, prng.Next()) + + cs.VerifierCommit(V1) + cs.VerifierCommit(V2) + cs.VerifierCommit(V3) + cs.VerifierCommit(V4) + + _, _, out := cs.Multiply(a_lc.Sub(x_lc).Add(b_lc.Sub(y_lc)), x_lc.Add(y_lc).Sub(a_lc.Add(b_lc))) + _, _, out2 := cs.Multiply(a_lc.Sub(x_lc).Add(b_lc.Sub(y_lc)), x_lc.Add(y_lc).Sub(a_lc.Add(b_lc))) + cs.Constrain(out.Add(out2)) + + wL, wR, wO, wV := cs.flatten(core.One()) + + lhs := new(ristretto.Scalar) + lhs.Add(lhs, core.InnerProduct(cs.aL, wL)) + lhs.Add(lhs, core.InnerProduct(cs.aR, wR)) + + rhs := new(ristretto.Scalar) + rhs.Add(rhs, core.InnerProduct(cs.aO, wO)) + rrhs := new(ristretto.Scalar) + rrhs.Add(rrhs, core.InnerProduct(cs.v, wV)) + + t.Logf("lhs: %v\n", lhs) + t.Logf("rhs: %v\n", rhs) + t.Logf("rrhs: %v\n", rrhs) + t.Logf("\nwL: %v\nwR: %v\nwO: %v\nwC: %v", wL, wR, wO, wV) + + proof := cs.Prove(cs.params, core.NewTranscript("")) + t.Logf("Proof Result: %v", cs.Verify(proof, cs.params, core.NewTranscript(""))) + +} diff --git a/primitives/bulletproofs/constraint_system.go b/primitives/bulletproofs/constraint_system.go new file mode 100644 index 0000000..dc0b5ed --- /dev/null +++ b/primitives/bulletproofs/constraint_system.go @@ -0,0 +1,424 @@ +package bulletproofs + +import ( + "crypto/rand" + "cwtch.im/tapir/primitives/core" + "git.openprivacy.ca/openprivacy/libricochet-go/log" + ristretto "github.com/gtank/ristretto255" +) + +type Variable struct { + Enum string + Index int +} + +type Term struct { + Variable + Coefficient *ristretto.Scalar +} + +// LinearCombination represent a vector of Terms +// We use the same structure as dalek-bulletproofs and reference the variable by an index into the canonical +// term in the constraint system. +type LinearCombination struct { + Terms []Term +} + +// Adding together linear combinations lc+olc = lc[:]+olc[:] +// (a * -1) + (b * -1) = (a*-1 + b * -1) +func (lc *LinearCombination) Add(olc *LinearCombination) *LinearCombination { + terms := lc.Terms + for _, term := range olc.Terms { + terms = append(terms, term) + } + return &LinearCombination{terms} +} + +// Subtracting a linear combination olc from lc results in the terms of lc being extended +// with those of lc, but with the coefficients in old being negated. +// (a * -1 - b * -1) = (a*-1 + b * 1) +func (lc *LinearCombination) Sub(olc *LinearCombination) *LinearCombination { + terms := lc.Terms + for _, term := range olc.Terms { + term.Coefficient = new(ristretto.Scalar).Negate(term.Coefficient) + terms = append(terms, term) + } + return &LinearCombination{terms} +} + +// ConstraintSystem allow constructing constraint systems over committed parameters, generating a proof of that constraint +// and allows a verifier to check that the constraint holds. +type ConstraintSystem struct { + aL core.ScalarVector + aR core.ScalarVector + aO core.ScalarVector + + v core.ScalarVector + vBlinding core.ScalarVector + constraints []*LinearCombination + + V core.GeneratorVector // todo clarify this + + params CommitmentsParams +} + +func NewConstrainSystem(params CommitmentsParams) ConstraintSystem { + cs := ConstraintSystem{} + cs.params = params + return cs +} + +// Multiply constructs a multiplication gate within our constraint system. +// It takes in two linear combinations a and b, evaluates the result c, and outputs three linear combinations +// one for the left input wire (a), one of the right input wire b, and one for the output wire (c) +// a and b are constrained to (a*-1) and (b*-1) +func (cs *ConstraintSystem) Multiply(a, b *LinearCombination) (*LinearCombination, *LinearCombination, *LinearCombination) { + + l := cs.eval(a) + r := cs.eval(b) + + c := new(ristretto.Scalar).Multiply(l, r) + + log.Debugf("%v & %v == %v", l, r, c) + + a.Terms = append(a.Terms, Term{Variable{"left", len(cs.aL)}, new(ristretto.Scalar).Negate(core.One())}) + b.Terms = append(b.Terms, Term{Variable{"right", len(cs.aR)}, new(ristretto.Scalar).Negate(core.One())}) + + o := &LinearCombination{[]Term{{Variable{"output", len(cs.aO)}, core.One()}}} + + cs.aL = append(cs.aL, l) + cs.aR = append(cs.aR, r) + cs.aO = append(cs.aO, c) + + cs.Constrain(a) + cs.Constrain(b) + + return a, b, o + +} + +// eval evaluates a linear combination lc, by looking up each term in the relevant vector in the constrain system +// and multiplying it by it's coefficient. +func (cs *ConstraintSystem) eval(lc *LinearCombination) *ristretto.Scalar { + result := new(ristretto.Scalar) + for _, term := range lc.Terms { + switch term.Enum { + case "left": + result.Add(result, new(ristretto.Scalar).Multiply(term.Coefficient, cs.aL[term.Index])) + case "right": + result.Add(result, new(ristretto.Scalar).Multiply(term.Coefficient, cs.aR[term.Index])) + case "output": + result.Add(result, new(ristretto.Scalar).Multiply(term.Coefficient, cs.aO[term.Index])) + case "committed": + result.Add(result, new(ristretto.Scalar).Multiply(term.Coefficient, cs.v[term.Index])) + case "one": + result.Add(result, term.Coefficient) + default: + panic("") + } + } + return result +} + +func (cs *ConstraintSystem) Commit(v *ristretto.Scalar, vBlind *ristretto.Scalar) (*ristretto.Element, *LinearCombination) { + i := len(cs.v) + cs.v = append(cs.v, v) + cs.vBlinding = append(cs.vBlinding, vBlind) + + V := core.MultiExp(core.ScalarVector{v, vBlind}, core.GeneratorVector{cs.params.g, cs.params.h}) + + return V, &LinearCombination{[]Term{{Variable{"committed", i}, core.One()}}} +} + +func (cs *ConstraintSystem) VerifierCommit(V *ristretto.Element) { + cs.V = append(cs.V, V) +} + +// Constrain adds the given linear combination to the constraints vector +// when constraints are flattened, each constraint will be evaluated and determine the +// weights to add. +func (cs *ConstraintSystem) Constrain(c *LinearCombination) { + cs.constraints = append(cs.constraints, c) +} + +// flatten constructs a set of 4 vectors of weights wL, wR, wO and wV representing the left inputs (L), right inputs (R), outputs (O) and +// committed values (C) such that +// + + = +func (cs *ConstraintSystem) flatten(z *ristretto.Scalar) (wL core.ScalarVector, wR core.ScalarVector, wO core.ScalarVector, wV core.ScalarVector) { + wL = make(core.ScalarVector, len(cs.aL)) + wR = make(core.ScalarVector, len(cs.aL)) + wO = make(core.ScalarVector, len(cs.aL)) + wV = make(core.ScalarVector, len(cs.v)) + + for i := 0; i < len(cs.aL); i++ { + wL[i] = new(ristretto.Scalar) + wR[i] = new(ristretto.Scalar) + wO[i] = new(ristretto.Scalar) + } + for i := 0; i < len(cs.v); i++ { + wV[i] = new(ristretto.Scalar) + } + + expZ := new(ristretto.Scalar).Add(z, new(ristretto.Scalar).Zero()) + for _, constraint := range cs.constraints { + for _, term := range constraint.Terms { + // log.Debugf("term: %v", term) + switch term.Enum { + case "left": + wL[term.Index].Add(wL[term.Index], new(ristretto.Scalar).Multiply(expZ, term.Coefficient)) + case "right": + wR[term.Index].Add(wR[term.Index], new(ristretto.Scalar).Multiply(expZ, term.Coefficient)) + case "output": + wO[term.Index].Add(wO[term.Index], new(ristretto.Scalar).Multiply(expZ, term.Coefficient)) + case "committed": + wV[term.Index].Subtract(wV[term.Index], new(ristretto.Scalar).Multiply(expZ, term.Coefficient)) + default: + panic("") + } + } + expZ = expZ.Multiply(expZ, z) + } + return +} + +type ConstraintProof struct { + Ai *ristretto.Element + Ao *ristretto.Element + S *ristretto.Element + T1 *ristretto.Element + T3 *ristretto.Element + T4 *ristretto.Element + T5 *ristretto.Element + T6 *ristretto.Element + Micro *ristretto.Scalar + TX *ristretto.Scalar + TX_Blinding *ristretto.Scalar + IPP InnerProductProof +} + +// Prove the constraint +func (cs *ConstraintSystem) Prove(c CommitmentsParams, transcript *core.Transcript) ConstraintProof { + + n := len(cs.aL) + + log.Debugf("n = %v\n", n) + + // Generate a prng to from this transcript and some external randomness + // We use this to generate the rest of our private scalars + // TODO: move to transcript + private := make([]byte, 64) + rand.Read(private) + prngTranscript := core.NewTranscript("private-transcript") + prngTranscript.AddToTranscript("randomness", private) + prng := prngTranscript.CommitToPRNG(transcript.OutputTranscriptToAudit()) + + alpha := prng.Next() + beta := prng.Next() + rho := prng.Next() + + Ai := core.MultiExp(append(cs.aL.Join(cs.aR), alpha), append(c.G.Join(c.H), c.h)) + Ao := core.MultiExp(cs.aO.SafeAppend(beta), c.G.SafeAppend(c.h)) + + Sl := make(core.ScalarVector, c.max) + Sr := make(core.ScalarVector, c.max) + for i := 0; i < c.max; i++ { + Sl[i] = prng.Next() + Sr[i] = prng.Next() + } + + S := core.MultiExp(append(Sl.Join(Sr), rho), c.G.Join(c.H).SafeAppend(c.h)) + + transcript.AddToTranscript("Ai", []byte(Ai.String())) + transcript.AddToTranscript("So", []byte(Ao.String())) + transcript.AddToTranscript("S", []byte(S.String())) + + y := transcript.CommitToTranscriptScalar("y") + yinv := new(ristretto.Scalar).Invert(y) + z := transcript.CommitToTranscriptScalar("z") + + log.Debugf("Calculating y^%v", n) + powY := core.PowerVector(y, n) + powYInv := core.PowerVector(yinv, n) + + wL, wR, wO, wV := cs.flatten(z) + + lhs := make([]core.ScalarVector, 4) + //lhs[0] = 0 + lhs[1] = core.EntrywiseSum(cs.aL, core.EntryWiseProduct(powYInv, wR)) + lhs[2] = cs.aO + lhs[3] = Sl + + rhs := make([]core.ScalarVector, 4) + rhs[0] = core.EntrywiseSub(wO, powY) + rhs[1] = core.EntrywiseSum(core.EntryWiseProduct(powY, cs.aR), wL) + // r2 = 0 + rhs[3] = core.EntryWiseProduct(powY, Sr) + + t1 := core.InnerProduct(lhs[1], rhs[0]) + t2 := new(ristretto.Scalar).Add(core.InnerProduct(lhs[1], rhs[1]), core.InnerProduct(lhs[2], rhs[0])) + t3 := new(ristretto.Scalar).Add(core.InnerProduct(lhs[2], rhs[1]), core.InnerProduct(lhs[3], rhs[0])) + t4 := new(ristretto.Scalar).Add(core.InnerProduct(lhs[1], rhs[3]), core.InnerProduct(lhs[3], rhs[1])) + t5 := core.InnerProduct(lhs[2], rhs[3]) + t6 := core.InnerProduct(lhs[3], rhs[3]) + + //t2 := core.InnerProduct(cs.aL,core.EntryWiseProduct(cs.aR, powY)) + //t2.Subtract(t2, core.InnerProduct(cs.aO, powY)) + //t2.Add(t2, ) + + // + tau1 := prng.Next() + tau3 := prng.Next() + tau4 := prng.Next() + tau5 := prng.Next() + tau6 := prng.Next() + + T1 := core.MultiExp(core.ScalarVector{t1, tau1}, core.GeneratorVector{c.g, c.h}) + T3 := core.MultiExp(core.ScalarVector{t3, tau3}, core.GeneratorVector{c.g, c.h}) + T4 := core.MultiExp(core.ScalarVector{t4, tau4}, core.GeneratorVector{c.g, c.h}) + T5 := core.MultiExp(core.ScalarVector{t5, tau5}, core.GeneratorVector{c.g, c.h}) + T6 := core.MultiExp(core.ScalarVector{t6, tau6}, core.GeneratorVector{c.g, c.h}) + + transcript.AddToTranscript("T1", []byte(T1.String())) + transcript.AddToTranscript("T3", []byte(T3.String())) + transcript.AddToTranscript("T4", []byte(T4.String())) + transcript.AddToTranscript("T5", []byte(T5.String())) + transcript.AddToTranscript("T6", []byte(T6.String())) + + //u := transcript.CommitToTranscriptScalar("u") + x := transcript.CommitToTranscriptScalar("x") + log.Debugf("x: %v", x) + + // T(X) = t0 + t1x + t2x + TX := new(ristretto.Scalar).Zero() + TX.Add(TX, new(ristretto.Scalar).Multiply(t1, x)) + x2 := new(ristretto.Scalar).Multiply(x, x) + TX.Add(TX, new(ristretto.Scalar).Multiply(t2, x2)) + x3 := new(ristretto.Scalar).Multiply(x2, x) + TX.Add(TX, new(ristretto.Scalar).Multiply(t3, x3)) + x4 := new(ristretto.Scalar).Multiply(x3, x) + TX.Add(TX, new(ristretto.Scalar).Multiply(t4, x4)) + x5 := new(ristretto.Scalar).Multiply(x4, x) + TX.Add(TX, new(ristretto.Scalar).Multiply(t5, x5)) + x6 := new(ristretto.Scalar).Multiply(x5, x) + TX.Add(TX, new(ristretto.Scalar).Multiply(t6, x6)) + + // evaluate l(X) + // lx0 = 0 + lx1 := core.VectorMulScalar(lhs[1], x) + lx2 := core.VectorMulScalar(lhs[2], x2) + lx3 := core.VectorMulScalar(lhs[3], new(ristretto.Scalar).Multiply(x2, x)) + lx := core.EntrywiseSum(core.EntrywiseSum(lx1, lx2), lx3) + + // evaluate r(X) + rx1 := core.VectorMulScalar(rhs[1], x) + // rx2 := core.VectorMulScalar(rhs[2],x2 ) rhs[2] == 0 + rx3 := core.VectorMulScalar(rhs[3], new(ristretto.Scalar).Multiply(x2, x)) + rx := core.EntrywiseSum(core.EntrywiseSum(rx1, rhs[0]), rx3) + + // calculate the inner product (t̂) + iplr := core.InnerProduct(lx, rx) + + // T(x) ?= + log.Debugf("T(X) = %v", TX) + log.Debugf("ipp = %v", iplr) + log.Debugf("equal: %v", TX.Equal(iplr) == 1) + + tau_blind := new(ristretto.Scalar).Multiply(tau1, x) + tau_blind.Add(tau_blind, new(ristretto.Scalar).Multiply(tau3, x3)) + tau_blind.Add(tau_blind, new(ristretto.Scalar).Multiply(tau4, x4)) + tau_blind.Add(tau_blind, new(ristretto.Scalar).Multiply(tau5, x5)) + tau_blind.Add(tau_blind, new(ristretto.Scalar).Multiply(tau6, x6)) + + //delta := core.InnerProduct(core.EntryWiseProduct(powYInv,wR), wL) + + tau_blind.Add(tau_blind, new(ristretto.Scalar).Multiply(core.InnerProduct(wV, cs.vBlinding), x2)) + + micro := new(ristretto.Scalar).Multiply(alpha, x) + micro.Add(micro, new(ristretto.Scalar).Multiply(beta, x2)) + micro.Add(micro, new(ristretto.Scalar).Multiply(rho, x3)) + + // generate h' + H_ := make(core.GeneratorVector, c.max) + H_[0] = c.H[0] + for i := 1; i < c.max; i++ { + H_[i] = new(ristretto.Element).ScalarMult(new(ristretto.Scalar).Invert(powY[i]), c.H[i]) + } + + P := core.MultiExp(lx.Join(rx), c.G.Join(H_)) + log.Debugf("P: %v", P) + + uP := new(ristretto.Element).Add(P, new(ristretto.Element).ScalarMult(iplr, c.u)) + log.Debugf("uP: %v", uP) + ipp := ProveInnerProduct(lx, rx, c.u, new(ristretto.Element).Add(new(ristretto.Element).Zero(), uP), core.CopyVector(c.G), core.CopyVector(H_), transcript) + + return ConstraintProof{Ai, Ao, S, T1, T3, T4, T5, T6, micro, TX, tau_blind, ipp} +} + +// Verify the constraint +func (cs *ConstraintSystem) Verify(proof ConstraintProof, c CommitmentsParams, transcript *core.Transcript) bool { + + log.Debugf("Verifying: %v\n", proof) + + n := len(cs.aL) + transcript.AddToTranscript("Ai", []byte(proof.Ai.String())) + transcript.AddToTranscript("So", []byte(proof.Ao.String())) + transcript.AddToTranscript("S", []byte(proof.S.String())) + + y := transcript.CommitToTranscriptScalar("y") + yinv := new(ristretto.Scalar).Invert(y) + z := transcript.CommitToTranscriptScalar("z") + + log.Debugf("Calculating y^%v", n) + powY := core.PowerVector(y, n) + powYInv := core.PowerVector(yinv, n) + + wL, wR, wO, wV := cs.flatten(z) + + // generate h' + H_ := make(core.GeneratorVector, c.max) + H_[0] = c.H[0] + for i := 1; i < c.max; i++ { + H_[i] = new(ristretto.Element).ScalarMult(powYInv[i], c.H[i]) + } + + Wl := core.MultiExp(wL, H_) + Wr := core.MultiExp(core.EntryWiseProduct(powYInv, wR), c.G) + Wo := core.MultiExp(wO, H_) + + transcript.AddToTranscript("T1", []byte(proof.T1.String())) + transcript.AddToTranscript("T3", []byte(proof.T3.String())) + transcript.AddToTranscript("T4", []byte(proof.T4.String())) + transcript.AddToTranscript("T5", []byte(proof.T5.String())) + transcript.AddToTranscript("T6", []byte(proof.T6.String())) + + x := transcript.CommitToTranscriptScalar("x") + log.Debugf("x: %v", x) + x2 := new(ristretto.Scalar).Multiply(x, x) + x3 := new(ristretto.Scalar).Multiply(x2, x) + x4 := new(ristretto.Scalar).Multiply(x3, x) + x5 := new(ristretto.Scalar).Multiply(x4, x) + x6 := new(ristretto.Scalar).Multiply(x5, x) + + HY := core.MultiExp(core.VectorNegate(powY), H_) + P := core.MultiExp(core.ScalarVector{x, x2, x3, x, x, core.One()}, core.GeneratorVector{proof.Ai, proof.Ao, proof.S, Wl, Wr, Wo}) + P.Add(P, HY) + + P_ := P.Subtract(P, new(ristretto.Element).ScalarMult(proof.Micro, c.h)) + uP := new(ristretto.Element).Add(P_, new(ristretto.Element).ScalarMult(proof.TX, c.u)) + + log.Debugf("P: %v", P_) + log.Debugf("uP: %v", uP) + + lhs := core.MultiExp(core.ScalarVector{proof.TX, proof.TX_Blinding}, core.GeneratorVector{c.g, c.h}) + + delta := core.InnerProduct(core.EntryWiseProduct(powYInv, wR), wL) + + rhs := core.MultiExp(core.VectorMulScalar(wV, x2).Join(core.ScalarVector{x, x3, x4, x5, x6, new(ristretto.Scalar).Multiply(x2, delta)}), + cs.V.Join(core.GeneratorVector{proof.T1, proof.T3, proof.T4, proof.T5, proof.T6, c.g})) + + log.Debugf("lhs: %v", lhs) + log.Debugf("rhs: %v", rhs) + //rhs := core.MultiExp(,cs.) + + return lhs.Equal(rhs) == 1 && Verify(proof.IPP, n, c.u, new(ristretto.Element).Add(new(ristretto.Element).Zero(), uP), core.CopyVector(c.G), core.CopyVector(H_), transcript) +} diff --git a/primitives/bulletproofs/inner_product.go b/primitives/bulletproofs/inner_product.go index c507bf2..d23fca1 100644 --- a/primitives/bulletproofs/inner_product.go +++ b/primitives/bulletproofs/inner_product.go @@ -42,7 +42,7 @@ func ProveInnerProduct(a, b core.ScalarVector, u *ristretto.Element, P *ristrett transcript.AddToTranscript("L", []byte(L.String())) Lvec = append(Lvec, L) - transcript.AddToTranscript("R",[]byte(R.String())) + transcript.AddToTranscript("R", []byte(R.String())) Rvec = append(Rvec, R) u := transcript.CommitToTranscriptScalar("u") @@ -59,10 +59,10 @@ func ProveInnerProduct(a, b core.ScalarVector, u *ristretto.Element, P *ristrett HL[i] = core.MultiExp(core.ScalarVector{u, uinv}, core.GeneratorVector{HL[i], HR[i]}) } - x2 := new(ristretto.Scalar).Multiply(u,u) - P_ := new(ristretto.Element).ScalarMult(x2,L) + x2 := new(ristretto.Scalar).Multiply(u, u) + P_ := new(ristretto.Element).ScalarMult(x2, L) P_.Add(P_, P) - P_.Add(P_, new(ristretto.Element).ScalarMult(new(ristretto.Scalar).Invert(x2),R)) + P_.Add(P_, new(ristretto.Element).ScalarMult(new(ristretto.Scalar).Invert(x2), R)) P = P_ transcript.AddToTranscript("P'", []byte(P.String())) @@ -105,10 +105,10 @@ func Verify(proof InnerProductProof, n int, u, P *ristretto.Element, G, H core.G HL[j] = core.MultiExp(core.ScalarVector{x, xinv}, core.GeneratorVector{HL[j], HR[j]}) } - x2 := new(ristretto.Scalar).Multiply(x,x) - P_ := new(ristretto.Element).ScalarMult(x2,proof.L[i]) + x2 := new(ristretto.Scalar).Multiply(x, x) + P_ := new(ristretto.Element).ScalarMult(x2, proof.L[i]) P_.Add(P_, P) - P_.Add(P_, new(ristretto.Element).ScalarMult(new(ristretto.Scalar).Invert(x2),proof.R[i])) + P_.Add(P_, new(ristretto.Element).ScalarMult(new(ristretto.Scalar).Invert(x2), proof.R[i])) P = P_ transcript.AddToTranscript("P'", []byte(P.String())) diff --git a/primitives/bulletproofs/inner_product_test.go b/primitives/bulletproofs/inner_product_test.go index d210fb3..fd62305 100644 --- a/primitives/bulletproofs/inner_product_test.go +++ b/primitives/bulletproofs/inner_product_test.go @@ -7,7 +7,6 @@ import ( "testing" ) - func assert(t *testing.T, expected *ristretto.Scalar, actual *ristretto.Scalar) { if expected.Equal(actual) == 1 { t.Logf("inner_product matched: %v", actual) @@ -30,7 +29,7 @@ func Test_inner_product(t *testing.T) { c = core.InnerProduct(a, b) check := new(ristretto.Scalar) - check.Add(one,one) + check.Add(one, one) assert(t, check, c) @@ -48,19 +47,18 @@ func TestProveInnerProduct(t *testing.T) { proverTranscript := core.NewTranscript("test_innerproductproof") verifierTranscript := core.NewTranscript("test_innerproductproof") - G := proverTranscript.CommitToGenerators("G",4) - H := proverTranscript.CommitToGenerators("H",4) + G := proverTranscript.CommitToGenerators("G", 4) + H := proverTranscript.CommitToGenerators("H", 4) u := proverTranscript.CommitToGenerator("u") - verifierTranscript.CommitToGenerators("G",4) - verifierTranscript.CommitToGenerators("H",4) + verifierTranscript.CommitToGenerators("G", 4) + verifierTranscript.CommitToGenerators("H", 4) verifierTranscript.CommitToGenerator("u") - c := core.InnerProduct(a,b) + c := core.InnerProduct(a, b) P_ := core.MultiExp(append(a.Join(b), c), append(core.GeneratorVector(G).Join(core.GeneratorVector(H)), u)) - proof := ProveInnerProduct(a, b, u, new(ristretto.Element).Add(new(ristretto.Element).Zero(), P_), core.CopyVector(G), core.CopyVector(H), proverTranscript) if Verify(proof, 4, u, P_, core.CopyVector(G), core.CopyVector(H), verifierTranscript) { diff --git a/primitives/bulletproofs/range_proof.go b/primitives/bulletproofs/range_proof.go index 7f4e89e..85ac52b 100644 --- a/primitives/bulletproofs/range_proof.go +++ b/primitives/bulletproofs/range_proof.go @@ -60,10 +60,10 @@ func GenerateRangeProof(value int32, c CommitmentsParams, transcript *core.Trans // Generate a prng to from this transcript and some external randomness // We use this to generate the rest of our private scalars // TODO: move to transcript - private := make([]byte,64) + private := make([]byte, 64) rand.Read(private) - prngTranscript:= core.NewTranscript("private-transcript") - prngTranscript.AddToTranscript("randomness",private) + prngTranscript := core.NewTranscript("private-transcript") + prngTranscript.AddToTranscript("randomness", private) prng := prngTranscript.CommitToPRNG(transcript.OutputTranscriptToAudit()) gamma := prng.Next() @@ -73,12 +73,10 @@ func GenerateRangeProof(value int32, c CommitmentsParams, transcript *core.Trans alpha := prng.Next() vs := new(ristretto.Scalar) - b := make([]byte,32) - copy(b,big.NewInt(int64(value)).Bytes()) + b := make([]byte, 32) + copy(b, big.NewInt(int64(value)).Bytes()) vs.Decode(b) - - V := core.MultiExp(core.ScalarVector{gamma, vs}, core.GeneratorVector{c.h, c.g}) A := core.MultiExp(append(aL.Join(aR), alpha), append(c.G.Join(c.H), c.h)) log.Debugf("vs: %v", vs) @@ -94,13 +92,13 @@ func GenerateRangeProof(value int32, c CommitmentsParams, transcript *core.Trans S := core.MultiExp(append(Sl.Join(Sr), p), append(c.G.Join(c.H), c.h)) transcript.AddToTranscript("A", []byte(A.String())) - transcript.AddToTranscript("S",[]byte(S.String())) + transcript.AddToTranscript("S", []byte(S.String())) y := transcript.CommitToTranscriptScalar("y") z := transcript.CommitToTranscriptScalar("z") y_n := core.PowerVector(y, c.max) - z2 := new(ristretto.Scalar).Multiply(z,z) + z2 := new(ristretto.Scalar).Multiply(z, z) l0 := core.VectorAddScalar(aL, new(ristretto.Scalar).Negate(z)) //l1 == Sr @@ -127,7 +125,7 @@ func GenerateRangeProof(value int32, c CommitmentsParams, transcript *core.Trans TX := new(ristretto.Scalar) TX.Add(new(ristretto.Scalar).Zero(), t0) TX.Add(TX, new(ristretto.Scalar).Multiply(t1, x)) - TX.Add(TX, new(ristretto.Scalar).Multiply(t2, new(ristretto.Scalar).Multiply(x,x))) + TX.Add(TX, new(ristretto.Scalar).Multiply(t2, new(ristretto.Scalar).Multiply(x, x))) l := core.EntrywiseSum(core.VectorAddScalar(aL, new(ristretto.Scalar).Negate(z)), core.VectorMulScalar(Sl, x)) _r := core.EntrywiseSum(core.VectorAddScalar(aR, z), core.VectorMulScalar(Sr, x)) @@ -143,18 +141,18 @@ func GenerateRangeProof(value int32, c CommitmentsParams, transcript *core.Trans H_ := make(core.GeneratorVector, c.max) H_[0] = c.H[0] for i := 1; i < c.max; i++ { - H_[i] = new(ristretto.Element).ScalarMult(new(ristretto.Scalar).Invert(y_n[i]),c.H[i],) + H_[i] = new(ristretto.Element).ScalarMult(new(ristretto.Scalar).Invert(y_n[i]), c.H[i]) } P := core.MultiExp(l.Join(r), c.G.Join(H_)) log.Debugf("P: %v", P) - uP := new(ristretto.Element).Add(P, new(ristretto.Element).ScalarMult(iplr,c.u)) + uP := new(ristretto.Element).Add(P, new(ristretto.Element).ScalarMult(iplr, c.u)) log.Debugf("uP: %v", uP) - ipp := ProveInnerProduct(l, r, c.u, new(ristretto.Element).Add(new(ristretto.Element).Zero(),uP), core.CopyVector(c.G), core.CopyVector(H_), transcript) + ipp := ProveInnerProduct(l, r, c.u, new(ristretto.Element).Add(new(ristretto.Element).Zero(), uP), core.CopyVector(c.G), core.CopyVector(H_), transcript) taux := new(ristretto.Scalar) - taux.Multiply(tau2, new(ristretto.Scalar).Multiply(x,x)) + taux.Multiply(tau2, new(ristretto.Scalar).Multiply(x, x)) taux.Add(taux, new(ristretto.Scalar).Multiply(tau1, x)) taux.Add(taux, new(ristretto.Scalar).Multiply(z2, gamma)) @@ -172,7 +170,7 @@ func VerifyRangeProof(proof RangeProof, c CommitmentsParams, transcript *core.Tr two_n := core.PowerVector(two, c.max) transcript.AddToTranscript("A", []byte(proof.A.String())) - transcript.AddToTranscript("S",[]byte(proof.S.String())) + transcript.AddToTranscript("S", []byte(proof.S.String())) y := transcript.CommitToTranscriptScalar("y") z := transcript.CommitToTranscriptScalar("z") transcript.AddToTranscript("T1", []byte(proof.T1.String())) @@ -183,25 +181,24 @@ func VerifyRangeProof(proof RangeProof, c CommitmentsParams, transcript *core.Tr H_ := make(core.GeneratorVector, c.max) H_[0] = c.H[0] for i := 1; i < c.max; i++ { - H_[i] = new(ristretto.Element).ScalarMult(new(ristretto.Scalar).Invert(y_n[i]),c.H[i],) + H_[i] = new(ristretto.Element).ScalarMult(new(ristretto.Scalar).Invert(y_n[i]), c.H[i]) } // check t = t(x) = t0 + t1.x + t1.x^2 lhs := core.MultiExp(core.ScalarVector{proof.InnerProduct, proof.TauX}, core.GeneratorVector{c.g, c.h}) - z2 := new(ristretto.Scalar) - z2.Multiply(z,z) + z2.Multiply(z, z) x2 := new(ristretto.Scalar) - x2.Multiply(x,x) + x2.Multiply(x, x) rhs := core.MultiExp(core.ScalarVector{z2, delta(y_n, z, c.max), x, x2}, core.GeneratorVector{proof.V, c.g, proof.T1, proof.T2}) log.Debugf("lhs: %v", lhs) log.Debugf("rhs: %v", rhs) log.Debugf("equal: %v", lhs.Equal(rhs)) - if lhs.Equal(rhs) == 1{ + if lhs.Equal(rhs) == 1 { // compute P @@ -211,19 +208,22 @@ func VerifyRangeProof(proof RangeProof, c CommitmentsParams, transcript *core.Tr negzG = negzG.Add(negzG, new(ristretto.Element).ScalarMult(negz, gen)) } - mul := core.EntrywiseSum(core.VectorMulScalar(y_n, z), core.VectorMulScalar(two_n, new(ristretto.Scalar).Multiply(z,z))) + mul := core.EntrywiseSum(core.VectorMulScalar(y_n, z), core.VectorMulScalar(two_n, new(ristretto.Scalar).Multiply(z, z))) - Pr := new(ristretto.Element).Add(new(ristretto.Element).Zero(),proof.A) - Pr = Pr.Add(Pr, new(ristretto.Element).ScalarMult(x,proof.S)) + Pr := new(ristretto.Element).Add(new(ristretto.Element).Zero(), proof.A) + Pr = Pr.Add(Pr, new(ristretto.Element).ScalarMult(x, proof.S)) Pr = Pr.Add(Pr, negzG) Pr = Pr.Add(Pr, core.MultiExp(mul, H_)) - Pl := new(ristretto.Element).Subtract(Pr, new(ristretto.Element).ScalarMult(proof.Mu,c.h)) + Pl := new(ristretto.Element).Subtract(Pr, new(ristretto.Element).ScalarMult(proof.Mu, c.h)) // check inner product uP := new(ristretto.Element).Add(Pl, new(ristretto.Element).ScalarMult(proof.InnerProduct, c.u)) - return Verify(proof.IPP, c.max, c.u, new(ristretto.Element).Add(new(ristretto.Element).Zero(),uP), core.CopyVector(c.G), core.CopyVector(H_), transcript) + log.Debugf("P: %v", Pl) + log.Debugf("uP: %v", uP) + + return Verify(proof.IPP, c.max, c.u, new(ristretto.Element).Add(new(ristretto.Element).Zero(), uP), core.CopyVector(c.G), core.CopyVector(H_), transcript) } return false } @@ -231,7 +231,7 @@ func VerifyRangeProof(proof RangeProof, c CommitmentsParams, transcript *core.Tr func delta(y_n core.ScalarVector, z *ristretto.Scalar, max int) *ristretto.Scalar { one := core.IdentityVector(1)[0] // (z-z^2) - z2 := new(ristretto.Scalar).Multiply(z,z) + z2 := new(ristretto.Scalar).Multiply(z, z) result := new(ristretto.Scalar) result.Subtract(z, z2) // (z-z^2) * <1^n,y^n> diff --git a/primitives/core/operations.go b/primitives/core/operations.go index 2f69dee..790dec5 100644 --- a/primitives/core/operations.go +++ b/primitives/core/operations.go @@ -14,10 +14,6 @@ type PointVector []*ristretto.Element // GeneratorVector explicit type checking type GeneratorVector []*ristretto.Element - - - - // CopyVector safely copies a vector func CopyVector(G GeneratorVector) GeneratorVector { H := make(GeneratorVector, len(G)) @@ -43,7 +39,7 @@ func InnerProduct(a, b ScalarVector) *ristretto.Scalar { // MultiExp takes in a vector of scalars = {a,b,c...} and a vector of generator = {A,B,C...} and outputs // {aA,bB,cC} func MultiExp(a ScalarVector, G GeneratorVector) *ristretto.Element { - if len(a) != len(G) { + if len(a) > len(G) { panic(fmt.Sprintf("len(a) = %v ; len(b) = %v;", len(a), len(G))) } result := new(ristretto.Element).Zero() @@ -54,25 +50,45 @@ func MultiExp(a ScalarVector, G GeneratorVector) *ristretto.Element { return result } +// Join is defined for a vector of Scalars +func (a ScalarVector) SafeAppend(b *ristretto.Scalar) ScalarVector { + list := make(ScalarVector, len(a)+1) + for i := 0; i < len(a); i++ { + list[i] = a[i] + } + list[len(a)] = b + return list +} + // Join is defined for a vector of Scalars func (a ScalarVector) Join(b ScalarVector) ScalarVector { list := make(ScalarVector, len(a)+len(b)) for i := 0; i < len(a); i++ { list[i] = a[i] } - for i := len(a); i < len(b)+len(b); i++ { + for i := len(a); i < len(a)+len(b); i++ { list[i] = b[i-len(a)] } return list } +// Join as defined for a vector of Generators +func (a GeneratorVector) SafeAppend(b *ristretto.Element) GeneratorVector { + list := make(GeneratorVector, len(a)+1) + for i := 0; i < len(a); i++ { + list[i] = a[i] + } + list[len(a)] = b + return list +} + // Join as defined for a vector of Generators func (a GeneratorVector) Join(b GeneratorVector) GeneratorVector { list := make(GeneratorVector, len(a)+len(b)) for i := 0; i < len(a); i++ { list[i] = a[i] } - for i := len(a); i < len(b)+len(b); i++ { + for i := len(a); i < len(a)+len(b); i++ { list[i] = b[i-len(a)] } return list @@ -88,6 +104,15 @@ func VectorAddScalar(vector ScalarVector, scalar *ristretto.Scalar) ScalarVector return result } +// VectorNegate takes in a vector v = {a,b,c..} and a scalar s and outputs {-a,-b,-c} +func VectorNegate(vector ScalarVector) ScalarVector { + result := make(ScalarVector, len(vector)) + for i := range vector { + result[i] = new(ristretto.Scalar).Negate(vector[i]) + } + return result +} + // VectorMulScalar takes in a vector v = {a,b,c..} and a scalar s and outputs {as,bs,cs....} func VectorMulScalar(vector ScalarVector, scalar *ristretto.Scalar) ScalarVector { result := make(ScalarVector, len(vector)) @@ -108,6 +133,16 @@ func EntrywiseSum(vector ScalarVector, vector2 ScalarVector) ScalarVector { return result } +// EntrywiseSubtract takes the entry wise sum of two vectors +func EntrywiseSub(vector ScalarVector, vector2 ScalarVector) ScalarVector { + result := make(ScalarVector, len(vector)) + for i, v := range vector { + result[i] = new(ristretto.Scalar) + result[i].Subtract(v, vector2[i]) + } + return result +} + // EntryWiseProduct takes the entry wise product of two vectors func EntryWiseProduct(vector ScalarVector, vector2 ScalarVector) ScalarVector { result := make(ScalarVector, len(vector)) @@ -118,6 +153,13 @@ func EntryWiseProduct(vector ScalarVector, vector2 ScalarVector) ScalarVector { return result } +// One returns a ristretto scalar == 1 +func One() *ristretto.Scalar { + one := new(ristretto.Scalar) + one.Decode([]byte{1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}) + return one +} + // IdentityVector is a convenience function to generate a vector v = {1,1,1...1} func IdentityVector(n int) ScalarVector { result := make(ScalarVector, n) @@ -141,4 +183,4 @@ func PowerVector(x *ristretto.Scalar, n int) ScalarVector { result[i].Multiply(result[i-1], x) } return result -} \ No newline at end of file +}