Add FetchMessagesFrom
This commit is contained in:
parent
32d81046cf
commit
cc4d5ee428
|
@ -2,8 +2,8 @@ package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"cwtch.im/cwtch/protocol/groups"
|
"cwtch.im/cwtch/protocol/groups"
|
||||||
"cwtch.im/cwtch/server/storage"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"git.openprivacy.ca/cwtch.im/server/storage"
|
||||||
"git.openprivacy.ca/cwtch.im/tapir"
|
"git.openprivacy.ca/cwtch.im/tapir"
|
||||||
"git.openprivacy.ca/cwtch.im/tapir/applications"
|
"git.openprivacy.ca/cwtch.im/tapir/applications"
|
||||||
"git.openprivacy.ca/cwtch.im/tapir/primitives/privacypass"
|
"git.openprivacy.ca/cwtch.im/tapir/primitives/privacypass"
|
||||||
|
@ -76,23 +76,25 @@ func (ta *TokenboardServer) Listen() {
|
||||||
case groups.ReplayRequestMessage:
|
case groups.ReplayRequestMessage:
|
||||||
if message.ReplayRequest != nil {
|
if message.ReplayRequest != nil {
|
||||||
log.Debugf("Received Replay Request %v", message.ReplayRequest)
|
log.Debugf("Received Replay Request %v", message.ReplayRequest)
|
||||||
messages := ta.LegacyMessageStore.FetchMessages()
|
messages := ta.LegacyMessageStore.FetchMessagesFrom(message.ReplayRequest.LastCommit)
|
||||||
response, _ := json.Marshal(groups.Message{MessageType: groups.ReplayResultMessage, ReplayResult: &groups.ReplayResult{NumMessages: len(messages)}})
|
response, _ := json.Marshal(groups.Message{MessageType: groups.ReplayResultMessage, ReplayResult: &groups.ReplayResult{NumMessages: len(messages)}})
|
||||||
log.Debugf("Sending Replay Response %v", groups.ReplayResult{NumMessages: len(messages)})
|
log.Debugf("Sending Replay Response %v", groups.ReplayResult{NumMessages: len(messages)})
|
||||||
ta.connection.Send(response)
|
ta.connection.Send(response)
|
||||||
|
lastSignature := message.ReplayRequest.LastCommit
|
||||||
for _, message := range messages {
|
for _, message := range messages {
|
||||||
|
lastSignature = message.Signature
|
||||||
data, _ = json.Marshal(message)
|
data, _ = json.Marshal(message)
|
||||||
ta.connection.Send(data)
|
ta.connection.Send(data)
|
||||||
}
|
}
|
||||||
log.Debugf("Finished Requested Sync")
|
log.Debugf("Finished Requested Sync")
|
||||||
// Set sync and then send any new messages that might have happened while we were syncing
|
// Set sync and then send any new messages that might have happened while we were syncing
|
||||||
ta.connection.SetCapability(groups.CwtchServerSyncedCapability)
|
ta.connection.SetCapability(groups.CwtchServerSyncedCapability)
|
||||||
newMessages := ta.LegacyMessageStore.FetchMessages()
|
// Because we have set the sync capability any new messages that arrive after this point will just
|
||||||
if len(newMessages) > len(messages) {
|
// need to do a basic lookup from the last seen message
|
||||||
for _, message := range newMessages[len(messages):] {
|
newMessages := ta.LegacyMessageStore.FetchMessagesFrom(lastSignature)
|
||||||
data, _ = json.Marshal(groups.Message{MessageType: groups.NewMessageMessage, NewMessage: &groups.NewMessage{EGM: *message}})
|
for _, message := range newMessages[len(messages):] {
|
||||||
ta.connection.Send(data)
|
data, _ = json.Marshal(groups.Message{MessageType: groups.NewMessageMessage, NewMessage: &groups.NewMessage{EGM: *message}})
|
||||||
}
|
ta.connection.Send(data)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
log.Debugf("Server Closing Connection Because of Malformed ReplayRequestMessage Packet")
|
log.Debugf("Server Closing Connection Because of Malformed ReplayRequestMessage Packet")
|
||||||
|
|
|
@ -13,15 +13,22 @@ import (
|
||||||
type MessageStoreInterface interface {
|
type MessageStoreInterface interface {
|
||||||
AddMessage(groups.EncryptedGroupMessage)
|
AddMessage(groups.EncryptedGroupMessage)
|
||||||
FetchMessages() []*groups.EncryptedGroupMessage
|
FetchMessages() []*groups.EncryptedGroupMessage
|
||||||
|
FetchMessagesFrom(signature []byte) []*groups.EncryptedGroupMessage
|
||||||
}
|
}
|
||||||
|
|
||||||
// SqliteMessageStore is an sqlite3 backed message store
|
// SqliteMessageStore is an sqlite3 backed message store
|
||||||
type SqliteMessageStore struct {
|
type SqliteMessageStore struct {
|
||||||
database *sql.DB
|
database *sql.DB
|
||||||
|
|
||||||
|
// Some prepared queries...
|
||||||
|
preparedInsertStatement *sql.Stmt // A Stmt is safe for concurrent use by multiple goroutines.
|
||||||
|
preparedFetchFromQuery *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close closes the underlying sqlite3 database to further changes
|
// Close closes the underlying sqlite3 database to further changes
|
||||||
func (s *SqliteMessageStore) Close() {
|
func (s *SqliteMessageStore) Close() {
|
||||||
|
s.preparedInsertStatement.Close()
|
||||||
|
s.preparedFetchFromQuery.Close()
|
||||||
s.database.Close()
|
s.database.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -32,16 +39,9 @@ func (s *SqliteMessageStore) AddMessage(message groups.EncryptedGroupMessage) {
|
||||||
log.Errorf("%q", err)
|
log.Errorf("%q", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
sqlStmt := `INSERT INTO messages(signature, ciphertext) values (?,?);`
|
stmt, err := s.preparedInsertStatement.Exec(base64.StdEncoding.EncodeToString(message.Signature), base64.StdEncoding.EncodeToString(message.Ciphertext))
|
||||||
stmt, err := s.database.Prepare(sqlStmt)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("%q: %s", err, sqlStmt)
|
log.Errorf("%v %q", stmt, err)
|
||||||
return
|
|
||||||
}
|
|
||||||
defer stmt.Close()
|
|
||||||
_, err = stmt.Exec(base64.StdEncoding.EncodeToString(message.Signature), base64.StdEncoding.EncodeToString(message.Ciphertext))
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("%q: %s\n", err, sqlStmt)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
tx.Commit()
|
tx.Commit()
|
||||||
|
@ -55,12 +55,31 @@ func (s SqliteMessageStore) FetchMessages() []*groups.EncryptedGroupMessage {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
|
return s.compileRows(rows)
|
||||||
|
}
|
||||||
|
|
||||||
|
// FetchMessagesFrom implements the MessageStoreInterface FetchMessagesFrom for sqlite message store
|
||||||
|
func (s SqliteMessageStore) FetchMessagesFrom(signature []byte) []*groups.EncryptedGroupMessage {
|
||||||
|
if signature == nil {
|
||||||
|
return s.FetchMessages()
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := s.preparedFetchFromQuery.Query(base64.StdEncoding.EncodeToString(signature))
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("%v", err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
return s.compileRows(rows)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SqliteMessageStore) compileRows(rows *sql.Rows) []*groups.EncryptedGroupMessage {
|
||||||
var messages []*groups.EncryptedGroupMessage
|
var messages []*groups.EncryptedGroupMessage
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var id int
|
var id int
|
||||||
var signature string
|
var signature string
|
||||||
var ciphertext string
|
var ciphertext string
|
||||||
err = rows.Scan(&id, &signature, &ciphertext)
|
err := rows.Scan(&id, &signature, &ciphertext)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Error fetching row %v", err)
|
log.Errorf("Error fetching row %v", err)
|
||||||
}
|
}
|
||||||
|
@ -92,5 +111,21 @@ func InitializeSqliteMessageStore(dbfile string) (*SqliteMessageStore, error) {
|
||||||
log.Infof("Database Initialized")
|
log.Infof("Database Initialized")
|
||||||
slms := new(SqliteMessageStore)
|
slms := new(SqliteMessageStore)
|
||||||
slms.database = db
|
slms.database = db
|
||||||
|
|
||||||
|
sqlStmt = `INSERT INTO messages(signature, ciphertext) values (?,?);`
|
||||||
|
stmt, err := slms.database.Prepare(sqlStmt)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("%q: %s", err, sqlStmt)
|
||||||
|
return nil, fmt.Errorf("%s: %q", sqlStmt, err)
|
||||||
|
}
|
||||||
|
slms.preparedInsertStatement = stmt
|
||||||
|
|
||||||
|
query, err := slms.database.Prepare("SELECT id, signature,ciphertext FROM messages WHERE id>=(SELECT id FROM messages WHERE signature=(?));")
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("%v", err)
|
||||||
|
return nil, fmt.Errorf("%s: %q", sqlStmt, err)
|
||||||
|
}
|
||||||
|
slms.preparedFetchFromQuery = query
|
||||||
|
|
||||||
return slms, nil
|
return slms, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,6 +2,7 @@ package storage
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"cwtch.im/cwtch/protocol/groups"
|
"cwtch.im/cwtch/protocol/groups"
|
||||||
|
"encoding/binary"
|
||||||
"git.openprivacy.ca/openprivacy/log"
|
"git.openprivacy.ca/openprivacy/log"
|
||||||
"os"
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
|
@ -15,22 +16,42 @@ func TestMessageStore(t *testing.T) {
|
||||||
t.Fatalf("Error: %v", err)
|
t.Fatalf("Error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
db.AddMessage(groups.EncryptedGroupMessage{
|
numMessages := 100
|
||||||
Signature: []byte("Hello world 2"),
|
|
||||||
Ciphertext: []byte("Hello world"),
|
|
||||||
})
|
|
||||||
|
|
||||||
db.AddMessage(groups.EncryptedGroupMessage{
|
t.Logf("Populating Database")
|
||||||
Signature: []byte("Hello world 1"),
|
for i := 0; i < numMessages; i++ {
|
||||||
Ciphertext: []byte("Hello world"),
|
buf := make([]byte, 4)
|
||||||
})
|
binary.PutUvarint(buf, uint64(i))
|
||||||
|
db.AddMessage(groups.EncryptedGroupMessage{
|
||||||
|
Signature: append([]byte("Hello world"), buf...),
|
||||||
|
Ciphertext: []byte("Hello world"),
|
||||||
|
})
|
||||||
|
t.Logf("Inserted %v", i)
|
||||||
|
}
|
||||||
|
// Wait for inserts to complete..
|
||||||
|
|
||||||
messages := db.FetchMessages()
|
messages := db.FetchMessages()
|
||||||
for _, message := range messages {
|
for _, message := range messages {
|
||||||
t.Logf("Message: %v", message)
|
t.Logf("Message: %v", message)
|
||||||
}
|
}
|
||||||
if len(messages) != 2 {
|
if len(messages) != numMessages {
|
||||||
t.Fatalf("Incorrect number of messages returned")
|
t.Fatalf("Incorrect number of messages returned")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
t.Logf("Testing FetchMessagesFrom...")
|
||||||
|
|
||||||
|
numToFetch := numMessages / 2
|
||||||
|
|
||||||
|
buf := make([]byte, 4)
|
||||||
|
binary.PutUvarint(buf, uint64(numToFetch))
|
||||||
|
sig := append([]byte("Hello world"), buf...)
|
||||||
|
messages = db.FetchMessagesFrom(sig)
|
||||||
|
for _, message := range messages {
|
||||||
|
t.Logf("Message: %v", message)
|
||||||
|
}
|
||||||
|
if len(messages) != numToFetch {
|
||||||
|
t.Fatalf("Incorrect number of messages returned : %v", len(messages))
|
||||||
|
}
|
||||||
|
|
||||||
db.Close()
|
db.Close()
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue