package storage
import (
"cwtch.im/cwtch/protocol/groups"
"database/sql"
"encoding/base64"
"fmt"
"git.openprivacy.ca/openprivacy/log"
"sync"
)
// MessageStoreInterface defines an interface to interact with a store of cwtch messages.
type MessageStoreInterface interface {
	// AddMessage adds one encrypted group message to the store.
	AddMessage(groups.EncryptedGroupMessage)
	// FetchMessages returns the messages held by the store.
	FetchMessages() []*groups.EncryptedGroupMessage
	// MessagesCount reports the number of messages held by the store.
	MessagesCount() int
	// FetchMessagesFrom returns messages starting from the one identified by signature.
	FetchMessagesFrom(signature []byte) []*groups.EncryptedGroupMessage
	// SetMessageCap changes the maximum number of messages the store retains.
	SetMessageCap(newcap int)
	// Close releases the store and its underlying resources.
	Close()
}
// SqliteMessageStore is an sqlite3 backed message store.
type SqliteMessageStore struct {
	// incMessageCounterFn is an optional hook (may be nil) invoked on every AddMessage call.
	incMessageCounterFn func()

	// messageCap is the number of messages kept before pruning starts (-1 means no cap),
	// and messageCount is the store's running tally of stored messages. After
	// construction both are only modified while countLock is held.
	messageCap   int
	messageCount int
	countLock    sync.Mutex

	// database is the open sqlite3 handle that every statement below is prepared on.
	database *sql.DB

	// Prepared queries, created once at initialization.
	// A Stmt is safe for concurrent use by multiple goroutines.
	preparedInsertStatement *sql.Stmt
	preparedFetchFromQuery  *sql.Stmt
	preparedFetchQuery      *sql.Stmt
	preparedCountQuery      *sql.Stmt
	preparedPruneStatement  *sql.Stmt
}
// Close closes the underlying sqlite3 database to further changes.
//
// Every prepared statement owned by the store is released first (previously
// only the insert and fetch-from statements were closed, leaking the rest).
// Nil handles are skipped, so a partially constructed store can also be closed.
func (s *SqliteMessageStore) Close() {
	statements := []*sql.Stmt{
		s.preparedInsertStatement,
		s.preparedFetchFromQuery,
		s.preparedFetchQuery,
		s.preparedCountQuery,
		s.preparedPruneStatement,
	}
	for _, stmt := range statements {
		if stmt != nil {
			// Best-effort shutdown: as before, close errors are deliberately ignored.
			_ = stmt.Close()
		}
	}
	if s.database != nil {
		_ = s.database.Close()
	}
}
// SetMessageCap replaces the maximum number of messages the store keeps and
// immediately runs the prune check against the new limit, all while holding
// countLock.
func (s *SqliteMessageStore) SetMessageCap(newcap int) {
	s.countLock.Lock()
	defer s.countLock.Unlock()

	s.messageCap = newcap
	// Apply the new cap now rather than waiting for the next insert.
	s.checkPruneMessages()
}
// AddMessage implements the MessageStoreInterface AddMessage for sqlite message store.
//
// The optional incMessageCounterFn hook fires on every call, before any
// validation. A message with an empty signature is dropped. Otherwise the
// signature and ciphertext are inserted base64-encoded; if the insert
// succeeds, the in-memory count is bumped and the cap is checked under countLock.
func (s *SqliteMessageStore) AddMessage(message groups.EncryptedGroupMessage) {
	if hook := s.incMessageCounterFn; hook != nil {
		hook()
	}

	// A message without a signature is clearly invalid; ignore it.
	if len(message.Signature) == 0 {
		return
	}

	encodedSignature := base64.StdEncoding.EncodeToString(message.Signature)
	encodedCiphertext := base64.StdEncoding.EncodeToString(message.Ciphertext)
	if res, err := s.preparedInsertStatement.Exec(encodedSignature, encodedCiphertext); err != nil {
		log.Errorf("%v %q", res, err)
		return
	}

	s.countLock.Lock()
	defer s.countLock.Unlock()
	s.messageCount++
	s.checkPruneMessages()
}
// checkPruneMessages deletes the oldest messages once messageCount exceeds
// messageCap. Both callers in this file hold countLock, and any new caller must too.
//
// It normally removes 10% of the cap. It always removes at least the excess
// over the cap, so caps below 10 are still enforced and a freshly lowered
// cap takes effect immediately.
// A negative cap (the store uses -1) disables pruning. Any other negative
// value used to reach SQLite as a negative LIMIT, which deletes every row.
func (s *SqliteMessageStore) checkPruneMessages() {
	if s.messageCap < 0 || s.messageCount <= s.messageCap {
		return
	}
	log.Debugf("Message Count: %d / Message Cap: %d, message cap exceeded, pruning oldest 10%%...", s.messageCount, s.messageCap)

	// Delete 10% of the cap, but never fewer than needed to get back under it.
	delCount := s.messageCap / 10
	if excess := s.messageCount - s.messageCap; excess > delCount {
		delCount = excess
	}

	res, err := s.preparedPruneStatement.Exec(delCount)
	if err != nil {
		// Nothing was deleted, so leave messageCount untouched.
		log.Errorf("error pruning %d messages: %v", delCount, err)
		return
	}
	// Prefer the driver's exact figure; fall back to the requested count.
	if deleted, err := res.RowsAffected(); err == nil {
		s.messageCount -= int(deleted)
	} else {
		s.messageCount -= delCount
	}
}
// MessagesCount returns the number of rows currently in the messages table,
// queried directly from the database, or -1 if the count could not be read
// (the error is logged).
func (s *SqliteMessageStore) MessagesCount() int {
	var rownum int
	// QueryRow surfaces query, iteration and "no rows" failures through Scan
	// and closes the result itself. The previous manual Next/Scan silently
	// dropped iteration errors.
	if err := s.preparedCountQuery.QueryRow().Scan(&rownum); err != nil {
		log.Errorf("error fetching rows: %v", err)
		return -1
	}
	return rownum
}
// FetchMessages implements the MessageStoreInterface FetchMessages for sqlite message store.
//
// It runs the prepared full-table query and decodes every row. On a query
// error it logs the error and returns nil.
func (s *SqliteMessageStore) FetchMessages() []*groups.EncryptedGroupMessage {
	rows, err := s.preparedFetchQuery.Query()
	if err == nil {
		defer rows.Close()
		return s.compileRows(rows)
	}
	log.Errorf("%v", err)
	return nil
}
// FetchMessagesFrom implements the MessageStoreInterface FetchMessagesFrom for sqlite message store.
//
// It returns the message whose signature matches, plus every message with a
// larger row id. An empty signature, or a signature that matches nothing,
// falls back to a full sync via FetchMessages. On a query error it logs the
// error and returns nil.
func (s *SqliteMessageStore) FetchMessagesFrom(signature []byte) []*groups.EncryptedGroupMessage {
	// An empty signature is a request for a complete sync.
	if len(signature) != 0 {
		encoded := base64.StdEncoding.EncodeToString(signature)
		rows, err := s.preparedFetchFromQuery.Query(encoded)
		if err != nil {
			log.Errorf("%v", err)
			return nil
		}
		defer rows.Close()

		if found := s.compileRows(rows); len(found) > 0 {
			return found
		}
		// No messages at all: either the signature never existed or the
		// server purged it. Either way, treat this as a full sync.
	}
	return s.FetchMessages()
}
// compileRows converts each (id, signature, ciphertext) row into an
// EncryptedGroupMessage, base64-decoding the two text columns that
// AddMessage encodes.
//
// A row that fails to scan or decode is logged and skipped. Previously such
// rows were appended as empty or partially decoded messages. Iteration
// errors reported by rows.Err are logged once the loop ends. The caller
// remains responsible for closing rows.
func (s *SqliteMessageStore) compileRows(rows *sql.Rows) []*groups.EncryptedGroupMessage {
	var messages []*groups.EncryptedGroupMessage
	for rows.Next() {
		var (
			id         int
			signature  string
			ciphertext string
		)
		if err := rows.Scan(&id, &signature, &ciphertext); err != nil {
			log.Errorf("Error fetching row %v", err)
			continue
		}
		rawSignature, err := base64.StdEncoding.DecodeString(signature)
		if err != nil {
			log.Errorf("error decoding signature of message %d: %v", id, err)
			continue
		}
		rawCiphertext, err := base64.StdEncoding.DecodeString(ciphertext)
		if err != nil {
			log.Errorf("error decoding ciphertext of message %d: %v", id, err)
			continue
		}
		messages = append(messages, &groups.EncryptedGroupMessage{
			Signature:  rawSignature,
			Ciphertext: rawCiphertext,
		})
	}
	// Next returns false on both exhaustion and failure; Err tells them apart.
	if err := rows.Err(); err != nil {
		log.Errorf("error iterating rows: %v", err)
	}
	return messages
}
// InitializeSqliteMessageStore creates a database `dbfile` with the necessary tables (if it doesn't already exist)
// and returns a message store backed by the open database.
//
// The store is configured as follows:
//   - messageCap is the number of messages kept before the oldest are pruned;
//     the store treats -1 as "no cap".
//   - incMessageCounterFn, if non-nil, is called on every AddMessage.
//   - messageCount is seeded from the rows already in the table.
//
// On any failure, the database handle and every statement prepared so far are
// closed, and a nil store is returned with the cause wrapped in the error.
func InitializeSqliteMessageStore(dbfile string, messageCap int, incMessageCounterFn func()) (*SqliteMessageStore, error) {
	db, err := sql.Open("sqlite3", dbfile)
	if err != nil {
		log.Errorf("database %v cannot be created or opened %v", dbfile, err)
		return nil, fmt.Errorf("database %v cannot be created or opened: %w", dbfile, err)
	}

	sqlStmt := ` CREATE TABLE IF NOT EXISTS messages (id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, signature TEXT UNIQUE NOT NULL, ciphertext TEXT NOT NULL); `
	if _, err = db.Exec(sqlStmt); err != nil {
		db.Close()
		log.Errorf("%q: %s", err, sqlStmt)
		return nil, fmt.Errorf("%s: %w", sqlStmt, err)
	}
	log.Infof("Database Initialized")

	slms := &SqliteMessageStore{
		database:            db,
		incMessageCounterFn: incMessageCounterFn,
		messageCap:          messageCap,
	}

	// Every query is prepared once, up front, directly into its field on the store.
	prepared := []struct {
		query string
		dest  **sql.Stmt
	}{
		{` INSERT INTO messages(signature, ciphertext) values (?,?); `, &slms.preparedInsertStatement},
		{"SELECT id, signature,ciphertext from messages", &slms.preparedFetchQuery},
		{"SELECT id, signature,ciphertext FROM messages WHERE id>=(SELECT id FROM messages WHERE signature=(?));", &slms.preparedFetchFromQuery},
		{"SELECT COUNT(*) from messages", &slms.preparedCountQuery},
		{"DELETE FROM messages WHERE id IN (SELECT id FROM messages ORDER BY id ASC LIMIT (?))", &slms.preparedPruneStatement},
	}
	for i, p := range prepared {
		stmt, err := db.Prepare(p.query)
		if err != nil {
			log.Errorf("%q: %s", err, p.query)
			// Don't leak the statements prepared so far or the database handle.
			for _, done := range prepared[:i] {
				(*done.dest).Close()
			}
			db.Close()
			return nil, fmt.Errorf("%s: %w", p.query, err)
		}
		*p.dest = stmt
	}

	slms.messageCount = slms.MessagesCount()
	return slms, nil
}