tapir/primitives/privacypass/tokenserver.go

151 lines
5.8 KiB
Go

package privacypass
import (
"crypto/hmac"
"crypto/rand"
"cwtch.im/tapir/persistence"
"cwtch.im/tapir/primitives/core"
"encoding/hex"
"fmt"
ristretto "github.com/gtank/ristretto255"
"golang.org/x/crypto/sha3"
"sync"
)
// TokenServer implements a token server: it blind-signs tokens with a private
// scalar k, proves those signatures in batches, and redeems (spends) tokens at
// most once.
type TokenServer struct {
	// k is the server's private signing scalar.
	k *ristretto.Scalar
	// Y is the public key, Y = k·G (recomputed from k whenever k is replaced).
	Y *ristretto.Element
	// seen holds hex-encoded spent token preimages; it is only consulted when
	// persistanceService is nil.
	seen map[string]bool
	// persistanceService, when non-nil, stores the key and the spent-token set;
	// nil means an ephemeral, in-memory server.
	persistanceService persistence.Service
	// mutex serializes SpendToken so the same token cannot be redeemed twice concurrently.
	mutex sync.Mutex
}
// SignedBatchWithProof encapsulates a signed batch of blinded tokens with a batch proof for verification.
// The JSON field names ("st", "dp") are part of the wire format.
type SignedBatchWithProof struct {
	// SignedTokens holds one signed token per blinded token in the request, in the same order.
	SignedTokens []SignedToken `json:"st"`
	// Proof is a single DLEQ proof covering the whole batch.
	Proof DLEQProof `json:"dp"`
}
// Bucket names used with the persistence service: tokenBucket records spent
// tokens, keyBucket holds the server's private key under the name "k".
const (
	tokenBucket = "tokens"
	keyBucket   = "keys"
)
// NewTokenServer creates a TokenServer with a freshly generated random private
// scalar k and the matching public key Y = k·G. The server has no persistence
// service, so its spent-token set lives only in memory; this makes it mostly
// useful for tests and other ephemeral instances.
//
// It panics if the system's secure random source cannot be read.
func NewTokenServer() *TokenServer {
	seed := make([]byte, 64)
	if _, err := rand.Read(seed); err != nil {
		// unable to generate secure random numbers
		panic("unable to generate secure random numbers")
	}
	k := new(ristretto.Scalar)
	k.FromUniformBytes(seed)
	// persistanceService is left nil and mutex at its zero value.
	return &TokenServer{
		k:    k,
		Y:    new(ristretto.Element).ScalarBaseMult(k),
		seen: make(map[string]bool),
	}
}
// NewTokenServerFromStore returns a TokenServer whose private key and
// spent-token set are backed by persistenceService.
//
// If the key bucket already holds a key named "k", that key is loaded and the
// public key Y is recomputed from it. Otherwise the freshly generated key is
// persisted so that later instances reuse it.
//
// It panics if the store cannot be set up, queried, written or read. Silently
// continuing with an unsaved or freshly generated key would make the server
// reject every token it had previously issued, because spending recomputes
// W = k·T with the current k.
func NewTokenServerFromStore(persistenceService persistence.Service) *TokenServer {
	tokenServer := NewTokenServer()
	for _, bucket := range []string{tokenBucket, keyBucket} {
		if err := persistenceService.Setup([]string{bucket}); err != nil {
			panic(fmt.Errorf("token server: setting up bucket %q: %w", bucket, err))
		}
	}
	exists, err := persistenceService.Check(keyBucket, "k")
	if err != nil {
		panic(err)
	}
	if !exists {
		// No stored key yet: save the one we have just generated.
		if err := persistenceService.Persist(keyBucket, "k", tokenServer.k); err != nil {
			panic(fmt.Errorf("token server: persisting private key: %w", err))
		}
	} else {
		// Use the stored key, replacing the freshly generated one.
		if err := persistenceService.Load(keyBucket, "k", tokenServer.k); err != nil {
			panic(fmt.Errorf("token server: loading private key: %w", err))
		}
		// recalculate public key from stored k
		tokenServer.Y = new(ristretto.Element).ScalarBaseMult(tokenServer.k)
	}
	tokenServer.persistanceService = persistenceService
	return tokenServer
}
// SignBlindedToken signs a blinded token by multiplying its point P by the
// server's private scalar k, returning the signed token Q = k·P.
func (ts *TokenServer) SignBlindedToken(bt BlindedToken) SignedToken {
	return SignedToken{Q: new(ristretto.Element).ScalarMult(ts.k, bt.P)}
}
// SignBlindedTokenBatch signs every token in blindedTokens and attaches one
// batch DLEQ proof, bound to transcript, showing all of them were signed with
// the server's key. SignedTokens is in the same order as blindedTokens, and is
// nil when blindedTokens is empty.
func (ts *TokenServer) SignBlindedTokenBatch(blindedTokens []BlindedToken, transcript *core.Transcript) SignedBatchWithProof {
	var signedTokens []SignedToken
	for i := range blindedTokens {
		signedTokens = append(signedTokens, ts.SignBlindedToken(blindedTokens[i]))
	}
	proof := ts.constructBatchProof(blindedTokens, signedTokens, transcript)
	return SignedBatchWithProof{SignedTokens: signedTokens, Proof: proof}
}
// SignBlindedTokenBatchWithConstraint signs a batch of blinded tokens under
// transcript, and additionally proves they were signed with the same private
// key as an existing token identified by constraintToken.
//
// The constraint point is T = FromUniformBytes(SHA3-512(constraintToken)) and
// W = k·T. The pair (T, W) is added as an extra entry to the batch DLEQ proof
// inputs. Only the signatures for blindedTokens are returned in SignedTokens;
// W itself is not returned. The caller's blindedTokens slice is never modified.
func (ts *TokenServer) SignBlindedTokenBatchWithConstraint(blindedTokens []BlindedToken, constraintToken []byte, transcript *core.Transcript) SignedBatchWithProof {
	var signedTokens []SignedToken
	for _, bt := range blindedTokens {
		signedTokens = append(signedTokens, ts.SignBlindedToken(bt))
	}
	Ht := sha3.Sum512(constraintToken)
	T := new(ristretto.Element).FromUniformBytes(Ht[:])
	// W == kT
	W := new(ristretto.Element).ScalarMult(ts.k, T)

	// Build the proof inputs in fresh slices. Appending directly to blindedTokens
	// would write into the caller's backing array whenever it has spare capacity.
	proofPs := make([]BlindedToken, 0, len(blindedTokens)+1)
	proofPs = append(append(proofPs, blindedTokens...), BlindedToken{P: T})
	proofQs := make([]SignedToken, 0, len(signedTokens)+1)
	proofQs = append(append(proofQs, signedTokens...), SignedToken{Q: W})

	return SignedBatchWithProof{signedTokens, ts.constructBatchProof(proofPs, proofQs, transcript)}
}
// constructBatchProof constructs a single batch proof that every signed token
// was signed correctly, i.e. that each signedTokens[i].Q equals k times
// blindedTokens[i].P for the same private scalar k that links the base point G
// to the public key Y.
//
// It first binds the statement into transcript: the protocol label, G, Y, and
// the %v-formatted P and Q vectors. It then draws one coefficient c_i per pair
// from the PRNG returned by transcript.CommitToPRNG and folds the pairs into
// M = sum(c_i·P_i) and Z = sum(c_i·Q_i). If every Q_i = k·P_i then Z = k·M, so
// one proof over (G, Y, M, Z) covers the whole batch.
//
// signedTokens must have at least len(blindedTokens) entries. signedTokens[i]
// is indexed for every i in blindedTokens, so a shorter slice panics.
//
// NOTE(review): the P and Q vectors enter the transcript via fmt's %v
// formatting, so a verifier must format them identically to reproduce the
// coefficients. This presumably relies on the element type having a
// deterministic String method rather than printing pointer addresses; confirm.
func (ts *TokenServer) constructBatchProof(blindedTokens []BlindedToken, signedTokens []SignedToken, transcript *core.Transcript) DLEQProof {
	transcript.NewProtocol(BatchProofProtocol)
	transcript.AddToTranscript(BatchProofX, new(ristretto.Element).Base().Encode(nil))
	transcript.AddToTranscript(BatchProofY, ts.Y.Encode(nil))
	transcript.AddToTranscript(BatchProofPVector, []byte(fmt.Sprintf("%v", blindedTokens)))
	transcript.AddToTranscript(BatchProofQVector, []byte(fmt.Sprintf("%v", signedTokens)))
	prng := transcript.CommitToPRNG("w")
	// M and Z accumulate the random linear combinations of the P and Q vectors.
	M := new(ristretto.Element).Zero()
	Z := new(ristretto.Element).Zero()
	for i := range blindedTokens {
		c := prng.Next()
		M = new(ristretto.Element).Add(new(ristretto.Element).ScalarMult(c, blindedTokens[i].P), M)
		Z = new(ristretto.Element).Add(new(ristretto.Element).ScalarMult(c, signedTokens[i].Q), Z)
	}
	// Presumably proves log_G(Y) == log_M(Z) with k as the witness (see DiscreteLogEquivalenceProof).
	return DiscreteLogEquivalenceProof(ts.k, new(ristretto.Element).Base(), ts.Y, M, Z, transcript)
}
// SpendToken redeems token against data. It checks that the token has never
// been spent before and that its MAC verifies under this server's key, and if
// both hold it records the token as spent.
//
// It returns nil when the token was accepted and recorded. It returns a
// non-nil error in any of these cases:
//   - the token was already spent;
//   - the spent-token store could not be queried or updated;
//   - the MAC does not verify.
//
// Store failures are treated as rejection (fail closed) so a token can never be
// accepted without being recorded. Calls are serialized by ts.mutex, so two
// concurrent calls cannot both accept the same token. Neither token.T nor data
// is modified.
func (ts *TokenServer) SpendToken(token SpentToken, data []byte) error {
	ts.mutex.Lock()
	defer ts.mutex.Unlock() // We only want 1 client at a time redeeming tokens to prevent double-spends

	tokenID := hex.EncodeToString(token.T)

	if ts.persistanceService == nil {
		if _, spent := ts.seen[tokenID]; spent {
			return fmt.Errorf("token: %v has already been spent", token)
		}
	} else {
		spent, err := ts.persistanceService.Check(tokenBucket, tokenID)
		if err != nil {
			// Without a successful lookup we cannot show the token is unspent.
			return fmt.Errorf("token: %v could not be checked against the spent-token store: %w", token, err)
		}
		if spent {
			return fmt.Errorf("token: %v has already been spent", token)
		}
	}

	// Recompute W = k·T from the token preimage and derive the MAC key from (T, W).
	Ht := sha3.Sum512(token.T)
	T := new(ristretto.Element).FromUniformBytes(Ht[:])
	W := new(ristretto.Element).ScalarMult(ts.k, T)
	// The full slice expression caps capacity at len, so append allocates
	// instead of writing into spare capacity of the caller's token.T.
	key := sha3.Sum256(append(token.T[:len(token.T):len(token.T)], W.Encode(nil)...))
	mac := hmac.New(sha3.New512, key[:])
	// NOTE(review): hash.Hash.Sum(b) appends the digest to b; it does not hash b.
	// The expected MAC is therefore data || HMAC(key, ""), and data is not
	// authenticated. The client presumably computes it the same way (TODO
	// confirm), so fixing this needs a coordinated protocol change. The capacity
	// cap keeps Sum from writing into the caller's data buffer.
	computedMAC := mac.Sum(data[:len(data):len(data)])
	if !hmac.Equal(token.MAC, computedMAC) {
		return fmt.Errorf("token: %v is invalid and/or has not been signed by this service", token)
	}

	if ts.persistanceService == nil {
		ts.seen[tokenID] = true
		return nil
	}
	if err := ts.persistanceService.Persist(tokenBucket, tokenID, true); err != nil {
		// Reject rather than accept a spend we failed to record (double-spend risk).
		return fmt.Errorf("token: %v is valid but could not be recorded as spent: %w", token, err)
	}
	return nil
}