From cc4d5ee4285ace839ab56997bde064c8efb193eb Mon Sep 17 00:00:00 2001 From: Sarah Jamie Lewis Date: Fri, 7 May 2021 13:43:15 -0700 Subject: [PATCH] Add FetchMessagesFrom --- server_tokenboard.go | 18 +++++++----- storage/message_store.go | 55 ++++++++++++++++++++++++++++------- storage/message_store_test.go | 39 +++++++++++++++++++------ 3 files changed, 85 insertions(+), 27 deletions(-) diff --git a/server_tokenboard.go b/server_tokenboard.go index 1b14272..f30eb4b 100644 --- a/server_tokenboard.go +++ b/server_tokenboard.go @@ -2,8 +2,8 @@ package server import ( "cwtch.im/cwtch/protocol/groups" - "cwtch.im/cwtch/server/storage" "encoding/json" + "git.openprivacy.ca/cwtch.im/server/storage" "git.openprivacy.ca/cwtch.im/tapir" "git.openprivacy.ca/cwtch.im/tapir/applications" "git.openprivacy.ca/cwtch.im/tapir/primitives/privacypass" @@ -76,23 +76,25 @@ func (ta *TokenboardServer) Listen() { case groups.ReplayRequestMessage: if message.ReplayRequest != nil { log.Debugf("Received Replay Request %v", message.ReplayRequest) - messages := ta.LegacyMessageStore.FetchMessages() + messages := ta.LegacyMessageStore.FetchMessagesFrom(message.ReplayRequest.LastCommit) response, _ := json.Marshal(groups.Message{MessageType: groups.ReplayResultMessage, ReplayResult: &groups.ReplayResult{NumMessages: len(messages)}}) log.Debugf("Sending Replay Response %v", groups.ReplayResult{NumMessages: len(messages)}) ta.connection.Send(response) + lastSignature := message.ReplayRequest.LastCommit for _, message := range messages { + lastSignature = message.Signature data, _ = json.Marshal(message) ta.connection.Send(data) } log.Debugf("Finished Requested Sync") // Set sync and then send any new messages that might have happened while we were syncing ta.connection.SetCapability(groups.CwtchServerSyncedCapability) - newMessages := ta.LegacyMessageStore.FetchMessages() - if len(newMessages) > len(messages) { - for _, message := range newMessages[len(messages):] { - data, _ = json.Marshal(groups.Message{MessageType: groups.NewMessageMessage, NewMessage: &groups.NewMessage{EGM: *message}}) - ta.connection.Send(data) - } + // Because we have set the sync capability any new messages that arrive after this point will just + // need to do a basic lookup from the last seen message + newMessages := ta.LegacyMessageStore.FetchMessagesFrom(lastSignature) + for _, message := range newMessages[len(messages):] { + data, _ = json.Marshal(groups.Message{MessageType: groups.NewMessageMessage, NewMessage: &groups.NewMessage{EGM: *message}}) + ta.connection.Send(data) } } else { log.Debugf("Server Closing Connection Because of Malformed ReplayRequestMessage Packet") diff --git a/storage/message_store.go b/storage/message_store.go index 66303b1..41e675f 100644 --- a/storage/message_store.go +++ b/storage/message_store.go @@ -13,15 +13,22 @@ import ( type MessageStoreInterface interface { AddMessage(groups.EncryptedGroupMessage) FetchMessages() []*groups.EncryptedGroupMessage + FetchMessagesFrom(signature []byte) []*groups.EncryptedGroupMessage } // SqliteMessageStore is an sqlite3 backed message store type SqliteMessageStore struct { database *sql.DB + + // Some prepared queries... + preparedInsertStatement *sql.Stmt // A Stmt is safe for concurrent use by multiple goroutines. + preparedFetchFromQuery *sql.Stmt } // Close closes the underlying sqlite3 database to further changes func (s *SqliteMessageStore) Close() { + s.preparedInsertStatement.Close() + s.preparedFetchFromQuery.Close() s.database.Close() } @@ -32,16 +39,9 @@ func (s *SqliteMessageStore) AddMessage(message groups.EncryptedGroupMessage) { log.Errorf("%q", err) return } - sqlStmt := `INSERT INTO messages(signature, ciphertext) values (?,?);` - stmt, err := s.database.Prepare(sqlStmt) + stmt, err := s.preparedInsertStatement.Exec(base64.StdEncoding.EncodeToString(message.Signature), base64.StdEncoding.EncodeToString(message.Ciphertext)) if err != nil { - log.Errorf("%q: %s", err, sqlStmt) - return - } - defer stmt.Close() - _, err = stmt.Exec(base64.StdEncoding.EncodeToString(message.Signature), base64.StdEncoding.EncodeToString(message.Ciphertext)) - if err != nil { - log.Errorf("%q: %s\n", err, sqlStmt) + log.Errorf("%v %q", stmt, err) return } tx.Commit() @@ -55,12 +55,31 @@ func (s SqliteMessageStore) FetchMessages() []*groups.EncryptedGroupMessage { return nil } defer rows.Close() + return s.compileRows(rows) +} + +// FetchMessagesFrom implements the MessageStoreInterface FetchMessagesFrom for sqlite message store +func (s SqliteMessageStore) FetchMessagesFrom(signature []byte) []*groups.EncryptedGroupMessage { + if signature == nil { + return s.FetchMessages() + } + + rows, err := s.preparedFetchFromQuery.Query(base64.StdEncoding.EncodeToString(signature)) + if err != nil { + log.Errorf("%v", err) + return nil + } + defer rows.Close() + return s.compileRows(rows) +} + +func (s *SqliteMessageStore) compileRows(rows *sql.Rows) []*groups.EncryptedGroupMessage { var messages []*groups.EncryptedGroupMessage for rows.Next() { var id int var signature string var ciphertext string - err = rows.Scan(&id, &signature, &ciphertext) + err := rows.Scan(&id, &signature, &ciphertext) if err != nil { log.Errorf("Error fetching row %v", err) } @@ -92,5 +111,21 @@ func InitializeSqliteMessageStore(dbfile string) (*SqliteMessageStore, error) { log.Infof("Database Initialized") slms := new(SqliteMessageStore) slms.database = db + + sqlStmt = `INSERT INTO messages(signature, ciphertext) values (?,?);` + stmt, err := slms.database.Prepare(sqlStmt) + if err != nil { + log.Errorf("%q: %s", err, sqlStmt) + return nil, fmt.Errorf("%s: %q", sqlStmt, err) + } + slms.preparedInsertStatement = stmt + + query, err := slms.database.Prepare("SELECT id, signature,ciphertext FROM messages WHERE id>=(SELECT id FROM messages WHERE signature=(?));") + if err != nil { + log.Errorf("%v", err) + return nil, fmt.Errorf("%s: %q", sqlStmt, err) + } + slms.preparedFetchFromQuery = query + return slms, nil } diff --git a/storage/message_store_test.go b/storage/message_store_test.go index dd5bd9a..eb343a0 100644 --- a/storage/message_store_test.go +++ b/storage/message_store_test.go @@ -2,6 +2,7 @@ package storage import ( "cwtch.im/cwtch/protocol/groups" + "encoding/binary" "git.openprivacy.ca/openprivacy/log" "os" "testing" @@ -15,22 +16,42 @@ func TestMessageStore(t *testing.T) { t.Fatalf("Error: %v", err) } - db.AddMessage(groups.EncryptedGroupMessage{ - Signature: []byte("Hello world 2"), - Ciphertext: []byte("Hello world"), - }) + numMessages := 100 - db.AddMessage(groups.EncryptedGroupMessage{ - Signature: []byte("Hello world 1"), - Ciphertext: []byte("Hello world"), - }) + t.Logf("Populating Database") + for i := 0; i < numMessages; i++ { + buf := make([]byte, 4) + binary.PutUvarint(buf, uint64(i)) + db.AddMessage(groups.EncryptedGroupMessage{ + Signature: append([]byte("Hello world"), buf...), + Ciphertext: []byte("Hello world"), + }) + t.Logf("Inserted %v", i) + } + // Wait for inserts to complete.. messages := db.FetchMessages() for _, message := range messages { t.Logf("Message: %v", message) } - if len(messages) != 2 { + if len(messages) != numMessages { t.Fatalf("Incorrect number of messages returned") } + + t.Logf("Testing FetchMessagesFrom...") + + numToFetch := numMessages / 2 + + buf := make([]byte, 4) + binary.PutUvarint(buf, uint64(numToFetch)) + sig := append([]byte("Hello world"), buf...) + messages = db.FetchMessagesFrom(sig) + for _, message := range messages { + t.Logf("Message: %v", message) + } + if len(messages) != numToFetch { + t.Fatalf("Incorrect number of messages returned : %v", len(messages)) + } + db.Close() }