Reduce new allocations, propagate errors in PRNG.Next()
continuous-integration/drone/push Build is passing Details
continuous-integration/drone/pr Build is passing Details

This commit is contained in:
Sarah Jamie Lewis 2022-04-19 15:16:07 -07:00
parent 4cce393e12
commit 3b4fee2e72
7 changed files with 78 additions and 27 deletions

View File

@ -5,6 +5,7 @@ import (
"git.openprivacy.ca/cwtch.im/tapir"
"git.openprivacy.ca/cwtch.im/tapir/primitives/core"
"git.openprivacy.ca/openprivacy/log"
ristretto "github.com/gtank/ristretto255"
)
// ProofOfWorkApplication forces the incoming connection to do proof of work before granting a capability
@ -66,9 +67,18 @@ func (powapp *ProofOfWorkApplication) solveChallenge(challenge []byte, prng core
var sum [32]byte
solution := []byte{}
solve := make([]byte, len(challenge)+32)
// reuse our allocation
buf := make([]byte, 64)
next := new(ristretto.Scalar)
encodedSolution := make([]byte, 0, 32)
for !solved {
solution = prng.Next().Encode(nil)
err := prng.Next(buf, next)
if err != nil {
// this will cause the challenge to fail...
log.Errorf("error completing challenge: %v", err)
return nil
}
solution = next.Encode(encodedSolution)
copy(solve[0:], solution[:])
copy(solve[len(solution):], challenge[:])
@ -80,6 +90,8 @@ func (powapp *ProofOfWorkApplication) solveChallenge(challenge []byte, prng core
solved = false
}
}
// reuse this allocated memory next time...
encodedSolution = encodedSolution[:0]
}
log.Debugf("Validated Challenge %v: %v %v\n", challenge, solution, sum)
return solution[:]

View File

@ -55,7 +55,10 @@ func (tokenapp *TokenApplication) Init(connection tapir.Connection) {
var blinded []privacypass.BlindedToken
err := json.Unmarshal(connection.Expect(), &blinded)
if err == nil {
batchProof := tokenapp.TokenService.SignBlindedTokenBatch(blinded, tokenapp.Transcript())
batchProof, err := tokenapp.TokenService.SignBlindedTokenBatch(blinded, tokenapp.Transcript())
if err != nil {
return
}
log.Debugf(tokenapp.Transcript().OutputTranscriptToAudit())
data, _ := json.Marshal(batchProof)
connection.Send(data)

View File

@ -2,6 +2,7 @@ package core
import (
"fmt"
"git.openprivacy.ca/openprivacy/log"
"github.com/gtank/merlin"
ristretto "github.com/gtank/ristretto255"
"golang.org/x/crypto/sha3"
@ -67,12 +68,14 @@ type PRNG struct {
}
// Next returns the next "random" scalar from the PRNG
func (prng *PRNG) Next() *ristretto.Scalar {
buf := [64]byte{}
io.ReadFull(prng.prng, buf[:])
next := new(ristretto.Scalar)
next.FromUniformBytes(buf[:])
return next
func (prng *PRNG) Next(buf []byte, next *ristretto.Scalar) error {
n, err := io.ReadFull(prng.prng, buf)
if n != 64 || err != nil {
log.Errorf("could not read prng: %v %v", n, err)
return fmt.Errorf("error fetching complete output from prng: %v", err)
}
next.FromUniformBytes(buf)
return nil
}
// CommitToPRNG commits the label to the transcript and derives a PRNG from the transcript.

View File

@ -95,8 +95,14 @@ func verifyBatchProof(dleq DLEQProof, Y *ristretto.Element, blindedTokens []Blin
prng := transcript.CommitToPRNG("w")
M := new(ristretto.Element).Zero()
Z := new(ristretto.Element).Zero()
buf := make([]byte, 64)
c := new(ristretto.Scalar)
for i := range blindedTokens {
c := prng.Next()
err := prng.Next(buf, c)
if err != nil {
log.Errorf("error verifying batch proof: %v", err)
return false
}
M = new(ristretto.Element).Add(new(ristretto.Element).ScalarMult(c, blindedTokens[i].P), M)
Z = new(ristretto.Element).Add(new(ristretto.Element).ScalarMult(c, signedTokens[i].Q), Z)
}

View File

@ -56,7 +56,12 @@ func TestToken_ConstrainToToken(t *testing.T) {
// Constraint forces T = kW to be part of the batch proof
// And because the batch proof must prove that *all* inputs share the same key and also checks the servers public key
// We get a consistency check for almost free.
signedTokens := server.SignBlindedTokenBatchWithConstraint([]BlindedToken{blindedToken2}, token.t, core.NewTranscript(""))
signedTokens, err := server.SignBlindedTokenBatchWithConstraint([]BlindedToken{blindedToken2}, token.t, core.NewTranscript(""))
if err != nil {
t.Fatalf("error signing tokens with constraints")
}
transcript := core.NewTranscript("")
// NOTE: For this to work token.t and token.W need to be obtain by the client from known source e.g. a public message board.
@ -78,7 +83,10 @@ func TestGenerateBlindedTokenBatch(t *testing.T) {
serverTranscript := core.NewTranscript("privacyPass")
tokens, blindedTokens := GenerateBlindedTokenBatch(10)
batchProof := server.SignBlindedTokenBatch(blindedTokens, serverTranscript)
batchProof, err := server.SignBlindedTokenBatch(blindedTokens, serverTranscript)
if err != nil {
t.Fatalf("error constructing signed/blinded token batch: %v", err)
}
verified := UnblindSignedTokenBatch(tokens, blindedTokens, batchProof.SignedTokens, server.Y, batchProof.Proof, clientTranscript)

View File

@ -7,6 +7,7 @@ import (
"fmt"
"git.openprivacy.ca/cwtch.im/tapir/persistence"
"git.openprivacy.ca/cwtch.im/tapir/primitives/core"
"git.openprivacy.ca/openprivacy/log"
ristretto "github.com/gtank/ristretto255"
"golang.org/x/crypto/sha3"
"sync"
@ -69,17 +70,23 @@ func (ts *TokenServer) SignBlindedToken(bt BlindedToken) SignedToken {
}
// SignBlindedTokenBatch signs a batch of blinded tokens under a given transcript
func (ts *TokenServer) SignBlindedTokenBatch(blindedTokens []BlindedToken, transcript *core.Transcript) SignedBatchWithProof {
func (ts *TokenServer) SignBlindedTokenBatch(blindedTokens []BlindedToken, transcript *core.Transcript) (*SignedBatchWithProof, error) {
var signedTokens []SignedToken
for _, bt := range blindedTokens {
signedTokens = append(signedTokens, ts.SignBlindedToken(bt))
}
return SignedBatchWithProof{signedTokens, ts.constructBatchProof(blindedTokens, signedTokens, transcript)}
proof, err := ts.constructBatchProof(blindedTokens, signedTokens, transcript)
if err != nil {
return nil, err
}
signedProof := SignedBatchWithProof{signedTokens, *proof}
return &signedProof, nil
}
// SignBlindedTokenBatchWithConstraint signs a batch of blinded tokens under a given transcript given a constraint that the tokens must be signed
// by the same public key as an existing token
func (ts *TokenServer) SignBlindedTokenBatchWithConstraint(blindedTokens []BlindedToken, constraintToken []byte, transcript *core.Transcript) SignedBatchWithProof {
func (ts *TokenServer) SignBlindedTokenBatchWithConstraint(blindedTokens []BlindedToken, constraintToken []byte, transcript *core.Transcript) (*SignedBatchWithProof, error) {
var signedTokens []SignedToken
for _, bt := range blindedTokens {
signedTokens = append(signedTokens, ts.SignBlindedToken(bt))
@ -89,11 +96,16 @@ func (ts *TokenServer) SignBlindedTokenBatchWithConstraint(blindedTokens []Blind
// W == kT
W := new(ristretto.Element).ScalarMult(ts.k, T)
blindedTokens = append(blindedTokens, BlindedToken{P: T})
return SignedBatchWithProof{signedTokens, ts.constructBatchProof(blindedTokens, append(signedTokens, SignedToken{Q: W}), transcript)}
proof, err := ts.constructBatchProof(blindedTokens, append(signedTokens, SignedToken{Q: W}), transcript)
if err != nil {
return nil, err
}
signedProof := SignedBatchWithProof{signedTokens, *proof}
return &signedProof, nil
}
// constructBatchProof construct a batch proof that all the signed tokens have been signed correctly
func (ts *TokenServer) constructBatchProof(blindedTokens []BlindedToken, signedTokens []SignedToken, transcript *core.Transcript) DLEQProof {
func (ts *TokenServer) constructBatchProof(blindedTokens []BlindedToken, signedTokens []SignedToken, transcript *core.Transcript) (*DLEQProof, error) {
transcript.NewProtocol(BatchProofProtocol)
transcript.AddToTranscript(BatchProofX, new(ristretto.Element).Base().Encode(nil))
transcript.AddToTranscript(BatchProofY, ts.Y.Encode(nil))
@ -104,12 +116,19 @@ func (ts *TokenServer) constructBatchProof(blindedTokens []BlindedToken, signedT
M := new(ristretto.Element).Zero()
Z := new(ristretto.Element).Zero()
buf := make([]byte, 64)
c := new(ristretto.Scalar)
for i := range blindedTokens {
c := prng.Next()
err := prng.Next(buf, c)
if err != nil {
log.Errorf("error constructing batch proof: %v", err)
return nil, err
}
M = new(ristretto.Element).Add(new(ristretto.Element).ScalarMult(c, blindedTokens[i].P), M)
Z = new(ristretto.Element).Add(new(ristretto.Element).ScalarMult(c, signedTokens[i].Q), Z)
}
return DiscreteLogEquivalenceProof(ts.k, new(ristretto.Element).Base(), ts.Y, M, Z, transcript)
proof := DiscreteLogEquivalenceProof(ts.k, new(ristretto.Element).Base(), ts.Y, M, Z, transcript)
return &proof, nil
}
// SpendToken returns true a SpentToken is valid and has never been spent before, false otherwise.

View File

@ -62,6 +62,7 @@ type connection struct {
MaxLength int
lock sync.Mutex
service Service
expectBuffer []byte
}
// NewConnection creates a new Connection
@ -74,7 +75,7 @@ func NewConnection(service Service, id *primitives.Identity, hostname string, ou
connection.outbound = outbound
connection.MaxLength = 8192
connection.service = service
connection.expectBuffer = make([]byte, 8192)
go connection.app.Init(connection)
return connection
}
@ -153,10 +154,9 @@ func (c *connection) closeInner() {
// Expect blocks and reads a single Tapir packet , from the connection.
func (c *connection) Expect() []byte {
buffer := make([]byte, c.MaxLength)
// Multiple goroutines may invoke methods on a Conn simultaneously.
// As such we don't need to mutex around closed.
n, err := io.ReadFull(c.conn, buffer)
n, err := io.ReadFull(c.conn, c.expectBuffer)
if n != c.MaxLength || err != nil {
log.Debugf("[%v -> %v] Wire Error Reading, Read %d bytes, Error: %v", c.hostname, c.identity.Hostname(), n, err)
@ -167,22 +167,22 @@ func (c *connection) Expect() []byte {
defer c.lock.Unlock()
if c.encrypted {
var decryptNonce [24]byte
copy(decryptNonce[:], buffer[:24])
decrypted, ok := secretbox.Open(nil, buffer[24:], &decryptNonce, &c.key)
copy(decryptNonce[:], c.expectBuffer[:24])
decrypted, ok := secretbox.Open(nil, c.expectBuffer[24:], &decryptNonce, &c.key)
if ok {
copy(buffer, decrypted)
copy(c.expectBuffer, decrypted)
} else {
log.Errorf("[%v -> %v] Error Decrypting Message On Wire", c.hostname, c.identity.Hostname())
c.closeInner()
return []byte{}
}
}
length, _ := binary.Uvarint(buffer[0:2])
length, _ := binary.Uvarint(c.expectBuffer[0:2])
if length+2 >= uint64(c.MaxLength) {
return []byte{}
}
//cplog.Debugf("[%v -> %v] Wire Receive: (%d) %x", c.hostname, c.ID.Hostname(), len, buffer)
return buffer[2 : length+2]
return c.expectBuffer[2 : length+2]
}
// SetEncryptionKey turns on application-level encryption on the connection using the given key.