Add FetchMessagesFrom
This commit is contained in:
parent
32d81046cf
commit
cc4d5ee428
|
@ -2,8 +2,8 @@ package server
|
|||
|
||||
import (
|
||||
"cwtch.im/cwtch/protocol/groups"
|
||||
"cwtch.im/cwtch/server/storage"
|
||||
"encoding/json"
|
||||
"git.openprivacy.ca/cwtch.im/server/storage"
|
||||
"git.openprivacy.ca/cwtch.im/tapir"
|
||||
"git.openprivacy.ca/cwtch.im/tapir/applications"
|
||||
"git.openprivacy.ca/cwtch.im/tapir/primitives/privacypass"
|
||||
|
@ -76,23 +76,25 @@ func (ta *TokenboardServer) Listen() {
|
|||
case groups.ReplayRequestMessage:
|
||||
if message.ReplayRequest != nil {
|
||||
log.Debugf("Received Replay Request %v", message.ReplayRequest)
|
||||
messages := ta.LegacyMessageStore.FetchMessages()
|
||||
messages := ta.LegacyMessageStore.FetchMessagesFrom(message.ReplayRequest.LastCommit)
|
||||
response, _ := json.Marshal(groups.Message{MessageType: groups.ReplayResultMessage, ReplayResult: &groups.ReplayResult{NumMessages: len(messages)}})
|
||||
log.Debugf("Sending Replay Response %v", groups.ReplayResult{NumMessages: len(messages)})
|
||||
ta.connection.Send(response)
|
||||
lastSignature := message.ReplayRequest.LastCommit
|
||||
for _, message := range messages {
|
||||
lastSignature = message.Signature
|
||||
data, _ = json.Marshal(message)
|
||||
ta.connection.Send(data)
|
||||
}
|
||||
log.Debugf("Finished Requested Sync")
|
||||
// Set sync and then send any new messages that might have happened while we were syncing
|
||||
ta.connection.SetCapability(groups.CwtchServerSyncedCapability)
|
||||
newMessages := ta.LegacyMessageStore.FetchMessages()
|
||||
if len(newMessages) > len(messages) {
|
||||
for _, message := range newMessages[len(messages):] {
|
||||
data, _ = json.Marshal(groups.Message{MessageType: groups.NewMessageMessage, NewMessage: &groups.NewMessage{EGM: *message}})
|
||||
ta.connection.Send(data)
|
||||
}
|
||||
// Because we have set the sync capability any new messages that arrive after this point will just
|
||||
// need to do a basic lookup from the last seen message
|
||||
newMessages := ta.LegacyMessageStore.FetchMessagesFrom(lastSignature)
|
||||
for _, message := range newMessages[len(messages):] {
|
||||
data, _ = json.Marshal(groups.Message{MessageType: groups.NewMessageMessage, NewMessage: &groups.NewMessage{EGM: *message}})
|
||||
ta.connection.Send(data)
|
||||
}
|
||||
} else {
|
||||
log.Debugf("Server Closing Connection Because of Malformed ReplayRequestMessage Packet")
|
||||
|
|
|
@ -13,15 +13,22 @@ import (
|
|||
type MessageStoreInterface interface {
|
||||
AddMessage(groups.EncryptedGroupMessage)
|
||||
FetchMessages() []*groups.EncryptedGroupMessage
|
||||
FetchMessagesFrom(signature []byte) []*groups.EncryptedGroupMessage
|
||||
}
|
||||
|
||||
// SqliteMessageStore is an sqlite3 backed message store
|
||||
type SqliteMessageStore struct {
|
||||
database *sql.DB
|
||||
|
||||
// Some prepared queries...
|
||||
preparedInsertStatement *sql.Stmt // A Stmt is safe for concurrent use by multiple goroutines.
|
||||
preparedFetchFromQuery *sql.Stmt
|
||||
}
|
||||
|
||||
// Close closes the underlying sqlite3 database to further changes
|
||||
func (s *SqliteMessageStore) Close() {
|
||||
s.preparedInsertStatement.Close()
|
||||
s.preparedFetchFromQuery.Close()
|
||||
s.database.Close()
|
||||
}
|
||||
|
||||
|
@ -32,16 +39,9 @@ func (s *SqliteMessageStore) AddMessage(message groups.EncryptedGroupMessage) {
|
|||
log.Errorf("%q", err)
|
||||
return
|
||||
}
|
||||
sqlStmt := `INSERT INTO messages(signature, ciphertext) values (?,?);`
|
||||
stmt, err := s.database.Prepare(sqlStmt)
|
||||
stmt, err := s.preparedInsertStatement.Exec(base64.StdEncoding.EncodeToString(message.Signature), base64.StdEncoding.EncodeToString(message.Ciphertext))
|
||||
if err != nil {
|
||||
log.Errorf("%q: %s", err, sqlStmt)
|
||||
return
|
||||
}
|
||||
defer stmt.Close()
|
||||
_, err = stmt.Exec(base64.StdEncoding.EncodeToString(message.Signature), base64.StdEncoding.EncodeToString(message.Ciphertext))
|
||||
if err != nil {
|
||||
log.Errorf("%q: %s\n", err, sqlStmt)
|
||||
log.Errorf("%v %q", stmt, err)
|
||||
return
|
||||
}
|
||||
tx.Commit()
|
||||
|
@ -55,12 +55,31 @@ func (s SqliteMessageStore) FetchMessages() []*groups.EncryptedGroupMessage {
|
|||
return nil
|
||||
}
|
||||
defer rows.Close()
|
||||
return s.compileRows(rows)
|
||||
}
|
||||
|
||||
// FetchMessagesFrom implements the MessageStoreInterface FetchMessagesFrom for sqlite message store
|
||||
func (s SqliteMessageStore) FetchMessagesFrom(signature []byte) []*groups.EncryptedGroupMessage {
|
||||
if signature == nil {
|
||||
return s.FetchMessages()
|
||||
}
|
||||
|
||||
rows, err := s.preparedFetchFromQuery.Query(base64.StdEncoding.EncodeToString(signature))
|
||||
if err != nil {
|
||||
log.Errorf("%v", err)
|
||||
return nil
|
||||
}
|
||||
defer rows.Close()
|
||||
return s.compileRows(rows)
|
||||
}
|
||||
|
||||
func (s *SqliteMessageStore) compileRows(rows *sql.Rows) []*groups.EncryptedGroupMessage {
|
||||
var messages []*groups.EncryptedGroupMessage
|
||||
for rows.Next() {
|
||||
var id int
|
||||
var signature string
|
||||
var ciphertext string
|
||||
err = rows.Scan(&id, &signature, &ciphertext)
|
||||
err := rows.Scan(&id, &signature, &ciphertext)
|
||||
if err != nil {
|
||||
log.Errorf("Error fetching row %v", err)
|
||||
}
|
||||
|
@ -92,5 +111,21 @@ func InitializeSqliteMessageStore(dbfile string) (*SqliteMessageStore, error) {
|
|||
log.Infof("Database Initialized")
|
||||
slms := new(SqliteMessageStore)
|
||||
slms.database = db
|
||||
|
||||
sqlStmt = `INSERT INTO messages(signature, ciphertext) values (?,?);`
|
||||
stmt, err := slms.database.Prepare(sqlStmt)
|
||||
if err != nil {
|
||||
log.Errorf("%q: %s", err, sqlStmt)
|
||||
return nil, fmt.Errorf("%s: %q", sqlStmt, err)
|
||||
}
|
||||
slms.preparedInsertStatement = stmt
|
||||
|
||||
query, err := slms.database.Prepare("SELECT id, signature,ciphertext FROM messages WHERE id>=(SELECT id FROM messages WHERE signature=(?));")
|
||||
if err != nil {
|
||||
log.Errorf("%v", err)
|
||||
return nil, fmt.Errorf("%s: %q", sqlStmt, err)
|
||||
}
|
||||
slms.preparedFetchFromQuery = query
|
||||
|
||||
return slms, nil
|
||||
}
|
||||
|
|
|
@ -2,6 +2,7 @@ package storage
|
|||
|
||||
import (
|
||||
"cwtch.im/cwtch/protocol/groups"
|
||||
"encoding/binary"
|
||||
"git.openprivacy.ca/openprivacy/log"
|
||||
"os"
|
||||
"testing"
|
||||
|
@ -15,22 +16,42 @@ func TestMessageStore(t *testing.T) {
|
|||
t.Fatalf("Error: %v", err)
|
||||
}
|
||||
|
||||
db.AddMessage(groups.EncryptedGroupMessage{
|
||||
Signature: []byte("Hello world 2"),
|
||||
Ciphertext: []byte("Hello world"),
|
||||
})
|
||||
numMessages := 100
|
||||
|
||||
db.AddMessage(groups.EncryptedGroupMessage{
|
||||
Signature: []byte("Hello world 1"),
|
||||
Ciphertext: []byte("Hello world"),
|
||||
})
|
||||
t.Logf("Populating Database")
|
||||
for i := 0; i < numMessages; i++ {
|
||||
buf := make([]byte, 4)
|
||||
binary.PutUvarint(buf, uint64(i))
|
||||
db.AddMessage(groups.EncryptedGroupMessage{
|
||||
Signature: append([]byte("Hello world"), buf...),
|
||||
Ciphertext: []byte("Hello world"),
|
||||
})
|
||||
t.Logf("Inserted %v", i)
|
||||
}
|
||||
// Wait for inserts to complete..
|
||||
|
||||
messages := db.FetchMessages()
|
||||
for _, message := range messages {
|
||||
t.Logf("Message: %v", message)
|
||||
}
|
||||
if len(messages) != 2 {
|
||||
if len(messages) != numMessages {
|
||||
t.Fatalf("Incorrect number of messages returned")
|
||||
}
|
||||
|
||||
t.Logf("Testing FetchMessagesFrom...")
|
||||
|
||||
numToFetch := numMessages / 2
|
||||
|
||||
buf := make([]byte, 4)
|
||||
binary.PutUvarint(buf, uint64(numToFetch))
|
||||
sig := append([]byte("Hello world"), buf...)
|
||||
messages = db.FetchMessagesFrom(sig)
|
||||
for _, message := range messages {
|
||||
t.Logf("Message: %v", message)
|
||||
}
|
||||
if len(messages) != numToFetch {
|
||||
t.Fatalf("Incorrect number of messages returned : %v", len(messages))
|
||||
}
|
||||
|
||||
db.Close()
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue