// tapir/primitives/privacypass/tokenserver.go
//
// (Repository listing metadata: 172 lines, 6.1 KiB, Go.)

package privacypass
import (
"crypto/hmac"
"crypto/rand"
"encoding/hex"
"fmt"
"git.openprivacy.ca/cwtch.im/tapir/persistence"
"git.openprivacy.ca/cwtch.im/tapir/primitives/core"
"git.openprivacy.ca/openprivacy/log"
ristretto "github.com/gtank/ristretto255"
"golang.org/x/crypto/sha3"
"sync"
)
// TokenServer implements a token server.
//
// It signs blinded tokens with a secret scalar k, exposes the matching public
// key Y, and redeems spent tokens, remembering which tokens have already been
// spent either in memory (seen) or in an optional persistence service.
//
// NOTE: NewTokenServer constructs this struct with a positional composite
// literal, so the field order below must not change without updating it.
type TokenServer struct {
// k is the secret signing scalar.
k *ristretto.Scalar
// Y is the public key, computed as k times the ristretto base point.
Y *ristretto.Element
// seen records spent tokens, keyed by hex-encoded token.T; only used when persistanceService is nil.
seen map[string]bool
// persistanceService, when non-nil, records spent tokens instead of seen. It is nil for servers made by NewTokenServer.
persistanceService persistence.Service
// mutex serializes Close and SpendToken; the signing methods do not take it.
mutex sync.Mutex
}
// SignedBatchWithProof encapsulates a signed batch of blinded tokens with a batch proof for verification
//
// NOTE: this type is built with positional composite literals in this file,
// so the field order must not change.
type SignedBatchWithProof struct {
// SignedTokens holds kP for each blinded token P in the request, in request order.
// For SignBlindedTokenBatchWithConstraint the proof additionally covers the
// constraint token, whose signature is NOT included here.
SignedTokens []SignedToken `json:"st"`
// Proof is a DLEQ batch proof that the tokens were signed with the secret behind the server's public key Y.
Proof DLEQProof `json:"dp"`
}
// tokenBucket names the persistence bucket in which spent tokens are recorded.
const (
	tokenBucket = "tokens"
)
// NewTokenServer creates an ephemeral TokenServer (used mostly for testing).
//
// The signing scalar is derived from 64 freshly drawn secure random bytes and
// the public key is Y = k*G. The server has no persistence service, so spent
// tokens are remembered only in memory for the lifetime of the value.
//
// It panics if the system's secure random number generator fails.
func NewTokenServer() *TokenServer {
	seed := make([]byte, 64)
	if _, err := rand.Read(seed); err != nil {
		// Without secure randomness there is no safe way to produce a signing key.
		panic("unable to generate secure random numbers")
	}

	secret := new(ristretto.Scalar)
	secret.SetUniformBytes(seed)

	return &TokenServer{
		k:    secret,
		Y:    new(ristretto.Element).ScalarBaseMult(secret),
		seen: make(map[string]bool),
		// persistanceService is left nil and mutex at its zero value.
	}
}
// NewTokenServerFromStore creates a TokenServer that signs with the supplied
// secret scalar k and records spent tokens in persistenceService.
//
// The spent-token bucket is set up in persistenceService before the server is
// returned, and the public key Y is recalculated from k. The server keeps a
// reference to k (it is not copied), so callers must not modify it afterwards.
// persistenceService must be non-nil.
func NewTokenServerFromStore(k *ristretto.Scalar, persistenceService persistence.Service) *TokenServer {
	// NOTE(review): any result returned by Setup is ignored here, as before — confirm whether it can fail.
	persistenceService.Setup([]string{tokenBucket})

	// Build the server directly: going through NewTokenServer would generate
	// (and risk panicking on) a random key that is immediately discarded.
	return &TokenServer{
		k:                  k,
		Y:                  new(ristretto.Element).ScalarBaseMult(k), // recalculate public key from k
		seen:               make(map[string]bool),
		persistanceService: persistenceService,
	}
}
// Close closes the persistence service backing this server, if there is one.
//
// Servers created by NewTokenServer have no persistence service; for them Close
// is a no-op instead of a nil-interface panic. Close holds the server mutex so
// it cannot run concurrently with SpendToken.
func (ts *TokenServer) Close() {
	ts.mutex.Lock()
	defer ts.mutex.Unlock()
	if ts.persistanceService != nil {
		ts.persistanceService.Close()
	}
}
// SignBlindedToken signs one blinded token: it returns Q = kP, where P is the
// blinded point supplied by the client and k is the server's secret scalar.
func (ts *TokenServer) SignBlindedToken(bt BlindedToken) SignedToken {
	return SignedToken{Q: new(ristretto.Element).ScalarMult(ts.k, bt.P)}
}
// SignBlindedTokenBatch signs every blinded token in order and attaches a
// single batch DLEQ proof, bound to transcript, that all of them were signed
// with the server's key. It returns an error if the proof cannot be built.
func (ts *TokenServer) SignBlindedTokenBatch(blindedTokens []BlindedToken, transcript *core.Transcript) (*SignedBatchWithProof, error) {
	var signedTokens []SignedToken
	for i := range blindedTokens {
		signedTokens = append(signedTokens, ts.SignBlindedToken(blindedTokens[i]))
	}

	proof, err := ts.constructBatchProof(blindedTokens, signedTokens, transcript)
	if err != nil {
		return nil, err
	}
	return &SignedBatchWithProof{SignedTokens: signedTokens, Proof: *proof}, nil
}
// SignBlindedTokenBatchWithConstraint signs a batch of blinded tokens under a given transcript given a constraint that the tokens must be signed
// by the same public key as an existing token.
//
// The constraint token's bytes are hashed with SHA3-512 and mapped to a group
// element T (the same derivation SpendToken uses), and W = kT is computed. The
// batch proof covers every (P_i, kP_i) pair plus (T, W), which demonstrates
// that the new tokens share a key with constraintToken. W itself is not
// returned in SignedTokens. The caller's blindedTokens slice is never modified.
func (ts *TokenServer) SignBlindedTokenBatchWithConstraint(blindedTokens []BlindedToken, constraintToken []byte, transcript *core.Transcript) (*SignedBatchWithProof, error) {
	var signedTokens []SignedToken
	for _, bt := range blindedTokens {
		signedTokens = append(signedTokens, ts.SignBlindedToken(bt))
	}

	Ht := sha3.Sum512(constraintToken)
	T, err := new(ristretto.Element).SetUniformBytes(Ht[:])
	if err != nil {
		return nil, err
	}
	// W == kT
	W := new(ristretto.Element).ScalarMult(ts.k, T)

	// Build the proof vectors in fresh slices. Appending directly to
	// blindedTokens could write into spare capacity of the caller's slice.
	proofP := make([]BlindedToken, 0, len(blindedTokens)+1)
	proofP = append(proofP, blindedTokens...)
	proofP = append(proofP, BlindedToken{P: T})

	proofQ := make([]SignedToken, 0, len(signedTokens)+1)
	proofQ = append(proofQ, signedTokens...)
	proofQ = append(proofQ, SignedToken{Q: W})

	proof, err := ts.constructBatchProof(proofP, proofQ, transcript)
	if err != nil {
		return nil, err
	}
	// signedTokens stays nil for an empty request, as before (JSON "st": null).
	return &SignedBatchWithProof{SignedTokens: signedTokens, Proof: *proof}, nil
}
// constructBatchProof constructs a single DLEQ proof that every signed token
// Q_i was computed as k*P_i from the corresponding blinded token P_i, using the
// secret k behind the public key Y.
//
// The generator, Y and both token vectors are added to transcript. Per-token
// weights c_i are then drawn from a PRNG committed to the transcript, and the
// proof is made over the weighted sums M = sum(c_i*P_i) and Z = sum(c_i*Q_i).
//
// blindedTokens and signedTokens must have the same length; otherwise an
// error is returned and transcript is left untouched.
func (ts *TokenServer) constructBatchProof(blindedTokens []BlindedToken, signedTokens []SignedToken, transcript *core.Transcript) (*DLEQProof, error) {
	// A shorter signedTokens would panic when indexed below. A longer one would
	// contain tokens that the weighted sums, and therefore the proof, never cover.
	if len(blindedTokens) != len(signedTokens) {
		return nil, fmt.Errorf("cannot construct batch proof: %d blinded tokens but %d signed tokens", len(blindedTokens), len(signedTokens))
	}

	// Verification presumably replays these transcript entries in this exact
	// order and encoding, so do not reorder or re-encode them.
	transcript.NewProtocol(BatchProofProtocol)
	transcript.AddToTranscript(BatchProofX, ristretto.NewGeneratorElement().Bytes())
	transcript.AddToTranscript(BatchProofY, ts.Y.Bytes())
	transcript.AddToTranscript(BatchProofPVector, []byte(fmt.Sprintf("%v", blindedTokens)))
	transcript.AddToTranscript(BatchProofQVector, []byte(fmt.Sprintf("%v", signedTokens)))
	prng := transcript.CommitToPRNG("w")

	M := ristretto.NewIdentityElement()
	Z := ristretto.NewIdentityElement()
	buf := make([]byte, 64)
	c := new(ristretto.Scalar)
	for i := range blindedTokens {
		// c is overwritten with the next transcript-derived weight c_i.
		err := prng.Next(buf, c)
		if err != nil {
			log.Errorf("error constructing batch proof: %v", err)
			return nil, err
		}
		M = new(ristretto.Element).Add(new(ristretto.Element).ScalarMult(c, blindedTokens[i].P), M)
		Z = new(ristretto.Element).Add(new(ristretto.Element).ScalarMult(c, signedTokens[i].Q), Z)
	}

	proof := DiscreteLogEquivalenceProof(ts.k, ristretto.NewGeneratorElement(), ts.Y, M, Z, transcript)
	return &proof, nil
}
// SpendToken redeems token against data. It returns nil when the token has not
// been spent before and token.MAC is a valid HMAC-SHA3-512 over data. The MAC
// key is SHA3-256(T || W), where W = k*H(T) is recomputed with this server's
// secret. A successfully redeemed token is then recorded as spent. In every
// other case, including a storage failure while checking or recording the
// spend, a non-nil error is returned.
//
// Spent tokens are tracked in memory when the server has no persistence
// service, and in the persistence service's token bucket otherwise.
func (ts *TokenServer) SpendToken(token SpentToken, data []byte) error {
	// Only one client at a time may redeem tokens, so the spent-check and the
	// spent-record below cannot interleave and allow a double-spend.
	ts.mutex.Lock()
	defer ts.mutex.Unlock()

	tokenID := hex.EncodeToString(token.T)
	if ts.persistanceService == nil {
		if _, spent := ts.seen[tokenID]; spent {
			return fmt.Errorf("token: %v has already been spent", token)
		}
	} else {
		spent, err := ts.persistanceService.Check(tokenBucket, tokenID)
		if err != nil {
			// Fail closed: if we cannot tell whether the token was spent, refuse it.
			return fmt.Errorf("checking whether token has been spent: %w", err)
		}
		if spent {
			return fmt.Errorf("token: %v has already been spent", token)
		}
	}

	Ht := sha3.Sum512(token.T)
	T, err := new(ristretto.Element).SetUniformBytes(Ht[:])
	if err != nil {
		return err
	}
	W := new(ristretto.Element).ScalarMult(ts.k, T)

	// Build T || W in a fresh buffer. append(token.T, ...) could write into
	// spare capacity of the caller's slice.
	wBytes := W.Bytes()
	keyMaterial := make([]byte, 0, len(token.T)+len(wBytes))
	keyMaterial = append(keyMaterial, token.T...)
	keyMaterial = append(keyMaterial, wBytes...)
	key := sha3.Sum256(keyMaterial)

	mac := hmac.New(sha3.New512, key[:])
	mac.Write(data)
	if !hmac.Equal(token.MAC, mac.Sum(nil)) {
		return fmt.Errorf("token: %v is invalid and/or has not been signed by this service", token)
	}

	if ts.persistanceService == nil {
		ts.seen[tokenID] = true
		return nil
	}
	if err := ts.persistanceService.Persist(tokenBucket, tokenID, true); err != nil {
		// The MAC was valid, but if the spend cannot be recorded the token could
		// be redeemed again, so refuse it.
		return fmt.Errorf("recording token as spent: %w", err)
	}
	return nil
}