234 lines
6.6 KiB
Go
234 lines
6.6 KiB
Go
package storage
|
|
|
|
import (
|
|
"cwtch.im/cwtch/protocol/groups"
|
|
"database/sql"
|
|
"encoding/base64"
|
|
"fmt"
|
|
"git.openprivacy.ca/openprivacy/log"
|
|
"sync"
|
|
)
|
|
|
|
// MessageStoreInterface defines an interface to interact with a store of cwtch messages.
|
|
type MessageStoreInterface interface {
|
|
AddMessage(groups.EncryptedGroupMessage)
|
|
FetchMessages() []*groups.EncryptedGroupMessage
|
|
MessagesCount() int
|
|
FetchMessagesFrom(signature []byte) []*groups.EncryptedGroupMessage
|
|
SetMessageCap(newcap int)
|
|
Close()
|
|
}
|
|
|
|
// SqliteMessageStore is an sqlite3 backed message store
|
|
type SqliteMessageStore struct {
|
|
incMessageCounterFn func()
|
|
messageCap int
|
|
|
|
messageCount int
|
|
countLock sync.Mutex
|
|
|
|
database *sql.DB
|
|
|
|
// Some prepared queries...
|
|
preparedInsertStatement *sql.Stmt // A Stmt is safe for concurrent use by multiple goroutines.
|
|
preparedFetchFromQuery *sql.Stmt
|
|
preparedFetchQuery *sql.Stmt
|
|
preparedCountQuery *sql.Stmt
|
|
preparedPruneStatement *sql.Stmt
|
|
}
|
|
|
|
// Close closes every prepared statement and then the underlying sqlite3
// database, after which the store must not be used. Handles that were never
// set (nil) are skipped, so Close is safe on a partially initialised store.
func (s *SqliteMessageStore) Close() {
	// All five prepared statements must be released, not just insert and
	// fetch-from; otherwise the fetch, count and prune statements leak.
	statements := []*sql.Stmt{
		s.preparedInsertStatement,
		s.preparedFetchFromQuery,
		s.preparedFetchQuery,
		s.preparedCountQuery,
		s.preparedPruneStatement,
	}
	for _, stmt := range statements {
		if stmt != nil {
			// Errors are ignored, as before: there is nothing useful to do
			// with them while tearing the store down.
			stmt.Close()
		}
	}
	if s.database != nil {
		s.database.Close()
	}
}
|
|
|
|
func (s *SqliteMessageStore) SetMessageCap(newcap int) {
|
|
s.countLock.Lock()
|
|
defer s.countLock.Unlock()
|
|
s.messageCap = newcap
|
|
s.checkPruneMessages()
|
|
}
|
|
|
|
// AddMessage implements the MessageStoreInterface AddMessage for sqlite message store.
//
// incMessageCounterFn (if set) is called for every message, including ones that
// are then rejected. Messages without a signature are dropped. Otherwise the
// signature and ciphertext are inserted base64-encoded; on a successful insert
// the cached count is incremented and the prune check is run.
func (s *SqliteMessageStore) AddMessage(message groups.EncryptedGroupMessage) {
	if counter := s.incMessageCounterFn; counter != nil {
		counter()
	}

	// A message with no signature is clearly invalid; ignore it.
	if len(message.Signature) == 0 {
		return
	}

	encodedSignature := base64.StdEncoding.EncodeToString(message.Signature)
	encodedCiphertext := base64.StdEncoding.EncodeToString(message.Ciphertext)
	result, err := s.preparedInsertStatement.Exec(encodedSignature, encodedCiphertext)
	if err != nil {
		log.Errorf("%v %q", result, err)
		return
	}

	s.countLock.Lock()
	defer s.countLock.Unlock()
	s.messageCount++
	s.checkPruneMessages()
}
|
|
|
|
func (s *SqliteMessageStore) checkPruneMessages() {
|
|
if s.messageCap != -1 && s.messageCount > s.messageCap {
|
|
log.Debugf("Message Count: %d / Message Cap: %d, message cap exceeded, pruning oldest 10%...", s.messageCount, s.messageCap)
|
|
// Delete 10% of messages
|
|
delCount := s.messageCap / 10
|
|
stmt, err := s.preparedPruneStatement.Exec(s.messageCap / 10)
|
|
if err != nil {
|
|
log.Errorf("%v %q", stmt, err)
|
|
}
|
|
s.messageCount -= delCount
|
|
}
|
|
}
|
|
|
|
func (s *SqliteMessageStore) MessagesCount() int {
|
|
rows, err := s.preparedCountQuery.Query()
|
|
|
|
if err != nil {
|
|
log.Errorf("%v", err)
|
|
return -1
|
|
}
|
|
defer rows.Close()
|
|
|
|
result := rows.Next()
|
|
if !result {
|
|
return -1
|
|
}
|
|
|
|
var rownum int
|
|
err = rows.Scan(&rownum)
|
|
if err != nil {
|
|
log.Errorf("error fetching rows: %v", err)
|
|
return -1
|
|
}
|
|
|
|
return rownum
|
|
}
|
|
|
|
// FetchMessages implements the MessageStoreInterface FetchMessages for sqlite message store
|
|
func (s *SqliteMessageStore) FetchMessages() []*groups.EncryptedGroupMessage {
|
|
rows, err := s.preparedFetchQuery.Query()
|
|
if err != nil {
|
|
log.Errorf("%v", err)
|
|
return nil
|
|
}
|
|
defer rows.Close()
|
|
return s.compileRows(rows)
|
|
}
|
|
|
|
// FetchMessagesFrom implements the MessageStoreInterface FetchMessagesFrom for sqlite message store.
//
// It returns the rows selected by the prepared fetch-from query for the given
// (base64-encoded) signature. An empty signature, or a signature that yields no
// rows, falls back to a full FetchMessages sync. It returns nil if the query
// cannot be executed.
func (s *SqliteMessageStore) FetchMessagesFrom(signature []byte) []*groups.EncryptedGroupMessage {
	// An empty signature is a request for a complete sync.
	if len(signature) == 0 {
		return s.FetchMessages()
	}

	encoded := base64.StdEncoding.EncodeToString(signature)
	rows, err := s.preparedFetchFromQuery.Query(encoded)
	if err != nil {
		log.Errorf("%v", err)
		return nil
	}
	defer rows.Close()

	if messages := s.compileRows(rows); len(messages) > 0 {
		return messages
	}

	// Nothing matched: either the signature never existed or the server has
	// since pruned it. Either way, treat this as a full sync.
	return s.FetchMessages()
}
|
|
|
|
// compileRows drains rows, each expected to hold (id, base64 signature, base64
// ciphertext), and decodes them into EncryptedGroupMessages in row order.
//
// A row that fails to scan or decode is logged and skipped, rather than being
// returned as an empty or partial message. Closing rows remains the caller's
// responsibility.
func (s *SqliteMessageStore) compileRows(rows *sql.Rows) []*groups.EncryptedGroupMessage {
	var messages []*groups.EncryptedGroupMessage
	for rows.Next() {
		var id int
		var signature, ciphertext string
		if err := rows.Scan(&id, &signature, &ciphertext); err != nil {
			log.Errorf("Error fetching row %v", err)
			continue
		}

		rawSignature, err := base64.StdEncoding.DecodeString(signature)
		if err != nil {
			log.Errorf("error decoding signature of message %d: %v", id, err)
			continue
		}
		rawCiphertext, err := base64.StdEncoding.DecodeString(ciphertext)
		if err != nil {
			log.Errorf("error decoding ciphertext of message %d: %v", id, err)
			continue
		}

		messages = append(messages, &groups.EncryptedGroupMessage{
			Signature:  rawSignature,
			Ciphertext: rawCiphertext,
		})
	}
	// rows.Next returning false may signal an error rather than the end of
	// the result set; surface it instead of silently truncating.
	if err := rows.Err(); err != nil {
		log.Errorf("error iterating message rows: %v", err)
	}
	return messages
}
|
|
|
|
// InitializeSqliteMessageStore creates a database `dbfile` with the necessary tables (if it doesn't already exist)
|
|
// and returns an open database
|
|
func InitializeSqliteMessageStore(dbfile string, messageCap int, incMessageCounterFn func()) (*SqliteMessageStore, error) {
|
|
db, err := sql.Open("sqlite3", dbfile)
|
|
if err != nil {
|
|
log.Errorf("database %v cannot be created or opened %v", dbfile, err)
|
|
return nil, fmt.Errorf("database %v cannot be created or opened: %v", dbfile, err)
|
|
}
|
|
sqlStmt := `CREATE TABLE IF NOT EXISTS messages (id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, signature TEXT UNIQUE NOT NULL, ciphertext TEXT NOT NULL);`
|
|
_, err = db.Exec(sqlStmt)
|
|
if err != nil {
|
|
db.Close()
|
|
log.Errorf("%q: %s", err, sqlStmt)
|
|
return nil, fmt.Errorf("%s: %q", sqlStmt, err)
|
|
}
|
|
log.Infof("Database Initialized")
|
|
slms := new(SqliteMessageStore)
|
|
slms.database = db
|
|
slms.incMessageCounterFn = incMessageCounterFn
|
|
slms.messageCap = messageCap
|
|
|
|
sqlStmt = `INSERT INTO messages(signature, ciphertext) values (?,?);`
|
|
stmt, err := slms.database.Prepare(sqlStmt)
|
|
if err != nil {
|
|
log.Errorf("%q: %s", err, sqlStmt)
|
|
return nil, fmt.Errorf("%s: %q", sqlStmt, err)
|
|
}
|
|
slms.preparedInsertStatement = stmt
|
|
|
|
sqlStmt = "SELECT id, signature,ciphertext from messages"
|
|
query, err := slms.database.Prepare(sqlStmt)
|
|
if err != nil {
|
|
log.Errorf("%q: %s", err, sqlStmt)
|
|
return nil, fmt.Errorf("%s: %q", sqlStmt, err)
|
|
}
|
|
slms.preparedFetchQuery = query
|
|
|
|
sqlStmt = "SELECT id, signature,ciphertext FROM messages WHERE id>=(SELECT id FROM messages WHERE signature=(?));"
|
|
query, err = slms.database.Prepare(sqlStmt)
|
|
if err != nil {
|
|
log.Errorf("%q: %s", err, sqlStmt)
|
|
return nil, fmt.Errorf("%s: %q", sqlStmt, err)
|
|
}
|
|
slms.preparedFetchFromQuery = query
|
|
|
|
sqlStmt = "SELECT COUNT(*) from messages"
|
|
stmt, err = slms.database.Prepare(sqlStmt)
|
|
if err != nil {
|
|
log.Errorf("%q: %s", err, sqlStmt)
|
|
return nil, fmt.Errorf("%s: %q", sqlStmt, err)
|
|
}
|
|
slms.preparedCountQuery = stmt
|
|
|
|
sqlStmt = "DELETE FROM messages WHERE id IN (SELECT id FROM messages ORDER BY id ASC LIMIT (?))"
|
|
stmt, err = slms.database.Prepare(sqlStmt)
|
|
if err != nil {
|
|
log.Errorf("%q: %s", err, sqlStmt)
|
|
return nil, fmt.Errorf("%s: %q", sqlStmt, err)
|
|
}
|
|
slms.preparedPruneStatement = stmt
|
|
|
|
slms.messageCount = slms.MessagesCount()
|
|
|
|
return slms, nil
|
|
}
|