Add FetchMessagesFrom

This commit is contained in:
Sarah Jamie Lewis 2021-05-07 13:43:15 -07:00
parent 32d81046cf
commit cc4d5ee428
3 changed files with 85 additions and 27 deletions

View File

@ -2,8 +2,8 @@ package server
import ( import (
"cwtch.im/cwtch/protocol/groups" "cwtch.im/cwtch/protocol/groups"
"cwtch.im/cwtch/server/storage"
"encoding/json" "encoding/json"
"git.openprivacy.ca/cwtch.im/server/storage"
"git.openprivacy.ca/cwtch.im/tapir" "git.openprivacy.ca/cwtch.im/tapir"
"git.openprivacy.ca/cwtch.im/tapir/applications" "git.openprivacy.ca/cwtch.im/tapir/applications"
"git.openprivacy.ca/cwtch.im/tapir/primitives/privacypass" "git.openprivacy.ca/cwtch.im/tapir/primitives/privacypass"
@ -76,23 +76,25 @@ func (ta *TokenboardServer) Listen() {
case groups.ReplayRequestMessage: case groups.ReplayRequestMessage:
if message.ReplayRequest != nil { if message.ReplayRequest != nil {
log.Debugf("Received Replay Request %v", message.ReplayRequest) log.Debugf("Received Replay Request %v", message.ReplayRequest)
messages := ta.LegacyMessageStore.FetchMessages() messages := ta.LegacyMessageStore.FetchMessagesFrom(message.ReplayRequest.LastCommit)
response, _ := json.Marshal(groups.Message{MessageType: groups.ReplayResultMessage, ReplayResult: &groups.ReplayResult{NumMessages: len(messages)}}) response, _ := json.Marshal(groups.Message{MessageType: groups.ReplayResultMessage, ReplayResult: &groups.ReplayResult{NumMessages: len(messages)}})
log.Debugf("Sending Replay Response %v", groups.ReplayResult{NumMessages: len(messages)}) log.Debugf("Sending Replay Response %v", groups.ReplayResult{NumMessages: len(messages)})
ta.connection.Send(response) ta.connection.Send(response)
lastSignature := message.ReplayRequest.LastCommit
for _, message := range messages { for _, message := range messages {
lastSignature = message.Signature
data, _ = json.Marshal(message) data, _ = json.Marshal(message)
ta.connection.Send(data) ta.connection.Send(data)
} }
log.Debugf("Finished Requested Sync") log.Debugf("Finished Requested Sync")
// Set sync and then send any new messages that might have happened while we were syncing // Set sync and then send any new messages that might have happened while we were syncing
ta.connection.SetCapability(groups.CwtchServerSyncedCapability) ta.connection.SetCapability(groups.CwtchServerSyncedCapability)
newMessages := ta.LegacyMessageStore.FetchMessages() // Because we have set the sync capability any new messages that arrive after this point will just
if len(newMessages) > len(messages) { // need to do a basic lookup from the last seen message
for _, message := range newMessages[len(messages):] { newMessages := ta.LegacyMessageStore.FetchMessagesFrom(lastSignature)
data, _ = json.Marshal(groups.Message{MessageType: groups.NewMessageMessage, NewMessage: &groups.NewMessage{EGM: *message}}) for _, message := range newMessages[len(messages):] {
ta.connection.Send(data) data, _ = json.Marshal(groups.Message{MessageType: groups.NewMessageMessage, NewMessage: &groups.NewMessage{EGM: *message}})
} ta.connection.Send(data)
} }
} else { } else {
log.Debugf("Server Closing Connection Because of Malformed ReplayRequestMessage Packet") log.Debugf("Server Closing Connection Because of Malformed ReplayRequestMessage Packet")

View File

@ -13,15 +13,22 @@ import (
type MessageStoreInterface interface { type MessageStoreInterface interface {
AddMessage(groups.EncryptedGroupMessage) AddMessage(groups.EncryptedGroupMessage)
FetchMessages() []*groups.EncryptedGroupMessage FetchMessages() []*groups.EncryptedGroupMessage
FetchMessagesFrom(signature []byte) []*groups.EncryptedGroupMessage
} }
// SqliteMessageStore is an sqlite3 backed message store // SqliteMessageStore is an sqlite3 backed message store
type SqliteMessageStore struct { type SqliteMessageStore struct {
database *sql.DB database *sql.DB
// Some prepared queries...
preparedInsertStatement *sql.Stmt // A Stmt is safe for concurrent use by multiple goroutines.
preparedFetchFromQuery *sql.Stmt
} }
// Close closes the underlying sqlite3 database to further changes // Close closes the underlying sqlite3 database to further changes
func (s *SqliteMessageStore) Close() { func (s *SqliteMessageStore) Close() {
s.preparedInsertStatement.Close()
s.preparedFetchFromQuery.Close()
s.database.Close() s.database.Close()
} }
@ -32,16 +39,9 @@ func (s *SqliteMessageStore) AddMessage(message groups.EncryptedGroupMessage) {
log.Errorf("%q", err) log.Errorf("%q", err)
return return
} }
sqlStmt := `INSERT INTO messages(signature, ciphertext) values (?,?);` stmt, err := s.preparedInsertStatement.Exec(base64.StdEncoding.EncodeToString(message.Signature), base64.StdEncoding.EncodeToString(message.Ciphertext))
stmt, err := s.database.Prepare(sqlStmt)
if err != nil { if err != nil {
log.Errorf("%q: %s", err, sqlStmt) log.Errorf("%v %q", stmt, err)
return
}
defer stmt.Close()
_, err = stmt.Exec(base64.StdEncoding.EncodeToString(message.Signature), base64.StdEncoding.EncodeToString(message.Ciphertext))
if err != nil {
log.Errorf("%q: %s\n", err, sqlStmt)
return return
} }
tx.Commit() tx.Commit()
@ -55,12 +55,31 @@ func (s SqliteMessageStore) FetchMessages() []*groups.EncryptedGroupMessage {
return nil return nil
} }
defer rows.Close() defer rows.Close()
return s.compileRows(rows)
}
// FetchMessagesFrom implements the MessageStoreInterface FetchMessagesFrom for sqlite message store
func (s SqliteMessageStore) FetchMessagesFrom(signature []byte) []*groups.EncryptedGroupMessage {
if signature == nil {
return s.FetchMessages()
}
rows, err := s.preparedFetchFromQuery.Query(base64.StdEncoding.EncodeToString(signature))
if err != nil {
log.Errorf("%v", err)
return nil
}
defer rows.Close()
return s.compileRows(rows)
}
func (s *SqliteMessageStore) compileRows(rows *sql.Rows) []*groups.EncryptedGroupMessage {
var messages []*groups.EncryptedGroupMessage var messages []*groups.EncryptedGroupMessage
for rows.Next() { for rows.Next() {
var id int var id int
var signature string var signature string
var ciphertext string var ciphertext string
err = rows.Scan(&id, &signature, &ciphertext) err := rows.Scan(&id, &signature, &ciphertext)
if err != nil { if err != nil {
log.Errorf("Error fetching row %v", err) log.Errorf("Error fetching row %v", err)
} }
@ -92,5 +111,21 @@ func InitializeSqliteMessageStore(dbfile string) (*SqliteMessageStore, error) {
log.Infof("Database Initialized") log.Infof("Database Initialized")
slms := new(SqliteMessageStore) slms := new(SqliteMessageStore)
slms.database = db slms.database = db
sqlStmt = `INSERT INTO messages(signature, ciphertext) values (?,?);`
stmt, err := slms.database.Prepare(sqlStmt)
if err != nil {
log.Errorf("%q: %s", err, sqlStmt)
return nil, fmt.Errorf("%s: %q", sqlStmt, err)
}
slms.preparedInsertStatement = stmt
query, err := slms.database.Prepare("SELECT id, signature,ciphertext FROM messages WHERE id>=(SELECT id FROM messages WHERE signature=(?));")
if err != nil {
log.Errorf("%v", err)
return nil, fmt.Errorf("%s: %q", sqlStmt, err)
}
slms.preparedFetchFromQuery = query
return slms, nil return slms, nil
} }

View File

@ -2,6 +2,7 @@ package storage
import ( import (
"cwtch.im/cwtch/protocol/groups" "cwtch.im/cwtch/protocol/groups"
"encoding/binary"
"git.openprivacy.ca/openprivacy/log" "git.openprivacy.ca/openprivacy/log"
"os" "os"
"testing" "testing"
@ -15,22 +16,42 @@ func TestMessageStore(t *testing.T) {
t.Fatalf("Error: %v", err) t.Fatalf("Error: %v", err)
} }
db.AddMessage(groups.EncryptedGroupMessage{ numMessages := 100
Signature: []byte("Hello world 2"),
Ciphertext: []byte("Hello world"),
})
db.AddMessage(groups.EncryptedGroupMessage{ t.Logf("Populating Database")
Signature: []byte("Hello world 1"), for i := 0; i < numMessages; i++ {
Ciphertext: []byte("Hello world"), buf := make([]byte, 4)
}) binary.PutUvarint(buf, uint64(i))
db.AddMessage(groups.EncryptedGroupMessage{
Signature: append([]byte("Hello world"), buf...),
Ciphertext: []byte("Hello world"),
})
t.Logf("Inserted %v", i)
}
// Wait for inserts to complete..
messages := db.FetchMessages() messages := db.FetchMessages()
for _, message := range messages { for _, message := range messages {
t.Logf("Message: %v", message) t.Logf("Message: %v", message)
} }
if len(messages) != 2 { if len(messages) != numMessages {
t.Fatalf("Incorrect number of messages returned") t.Fatalf("Incorrect number of messages returned")
} }
t.Logf("Testing FetchMessagesFrom...")
numToFetch := numMessages / 2
buf := make([]byte, 4)
binary.PutUvarint(buf, uint64(numToFetch))
sig := append([]byte("Hello world"), buf...)
messages = db.FetchMessagesFrom(sig)
for _, message := range messages {
t.Logf("Message: %v", message)
}
if len(messages) != numToFetch {
t.Fatalf("Incorrect number of messages returned : %v", len(messages))
}
db.Close() db.Close()
} }