diff --git a/server.go b/server.go index 56e7247..854b2b8 100644 --- a/server.go +++ b/server.go @@ -113,7 +113,7 @@ func (s *server) Run(acn connectivity.ACN) error { } var err error - s.messageStore, err = storage.InitializeSqliteMessageStore(path.Join(s.config.ConfigDir, "cwtch.messages"), s.incMessageCount) + s.messageStore, err = storage.InitializeSqliteMessageStore(path.Join(s.config.ConfigDir, "cwtch.messages"), s.config.GetMaxMessages(), s.incMessageCount) if err != nil { return fmt.Errorf("could not open database: %v", err) } @@ -160,7 +160,7 @@ func (s *server) CheckStatus() (bool, error) { } // Stop turns off the server so it cannot receive connections and frees most resourses. -// The server is still in a reRunable state and tokenServer still has an active persistance +// The server is still in a reRunable state and tokenServer still has an active persistence func (s *server) Stop() { log.Infof("Shutting down server") s.lock.Lock() @@ -176,7 +176,7 @@ func (s *server) Stop() { } } -// Destroy frees the last of the resources the server has active (toklenServer persistance) leaving it un-re-runable and completely shutdown +// Destroy frees the last of the resources the server has active (toklenServer persistence) leaving it un-re-runable and completely shutdown func (s *server) Destroy() { s.Stop() s.lock.Lock() @@ -246,6 +246,17 @@ func (s *server) SetAttribute(key, val string) { s.config.SetAttribute(key, val) } +// GetMessageCap gets a server's MaxStorageMBs value +func (s *server) GetMaxStoreageMBs() int { + return s.config.GetMaxMessageMBs() +} + +// SetMaxStoreageMBs sets a server's MaxStoreageMBs and sets MaxMessages for storage (which can trigger a prune) +func (s *server) SetMaxStoreageMBs(val int) { + s.config.SetMaxMessageMBs(val) + s.messageStore.SetMessageCap(s.config.GetMaxMessages()) +} + // SetMonitorLogging turns on or off the monitor logging suite, and logging to a file in the server dir func (s *server) SetMonitorLogging(do bool) { s.config.ServerReporting.LogMetricsToFile = do diff --git a/serverConfig.go b/serverConfig.go index 6f9b124..eaee95b 100644 --- a/serverConfig.go +++ b/serverConfig.go @@ -42,6 +42,9 @@ type Reporting struct { LogMetricsToFile bool `json:"logMetricsToFile"` } +// messages are ~4kb of storage +const MessagesPerMB = 250 + // Config is a struct for storing basic server configuration type Config struct { ConfigDir string `json:"-"` @@ -62,6 +65,10 @@ type Config struct { Attributes map[string]string `json:"attributes"` + // messages are ~4kb of storage + // -1 == infinite + MaxStorageMBs int `json:"maxStorageMBs"` + lock sync.Mutex encFileStore storage.FileStore } @@ -90,6 +97,7 @@ func initDefaultConfig(configDir, filename string, encrypted bool) *Config { LogMetricsToFile: false, } config.Attributes[AttrAutostart] = "false" + config.MaxStorageMBs = -1 k := new(ristretto255.Scalar) b := make([]byte, 64) @@ -212,3 +220,26 @@ func (config *Config) GetAttribute(key string) string { defer config.lock.Unlock() return config.Attributes[key] } + +// GetMaxMessages returns the config setting for Max messages converting from MaxMB to messages +// or -1 for infinite +func (config *Config) GetMaxMessages() int { + config.lock.Lock() + defer config.lock.Unlock() + if config.MaxStorageMBs == -1 { + return -1 + } + return config.MaxStorageMBs * MessagesPerMB +} + +func (config *Config) GetMaxMessageMBs() int { + config.lock.Lock() + defer config.lock.Unlock() + return config.MaxStorageMBs +} + +func (config *Config) SetMaxMessageMBs(newval int) { + config.lock.Lock() + defer config.lock.Unlock() + config.MaxStorageMBs = newval +} diff --git a/storage/message_store.go b/storage/message_store.go index 4f90734..d393f99 100644 --- a/storage/message_store.go +++ b/storage/message_store.go @@ -6,6 +6,7 @@ import ( "encoding/base64" "fmt" "git.openprivacy.ca/openprivacy/log" + "sync" ) // MessageStoreInterface defines an interface to interact with a store of cwtch messages. @@ -14,17 +15,26 @@ type MessageStoreInterface interface { FetchMessages() []*groups.EncryptedGroupMessage MessagesCount() int FetchMessagesFrom(signature []byte) []*groups.EncryptedGroupMessage + SetMessageCap(newcap int) Close() } // SqliteMessageStore is an sqlite3 backed message store type SqliteMessageStore struct { incMessageCounterFn func() - database *sql.DB + messageCap int + + messageCount int + countLock sync.Mutex + + database *sql.DB // Some prepared queries... preparedInsertStatement *sql.Stmt // A Stmt is safe for concurrent use by multiple goroutines. preparedFetchFromQuery *sql.Stmt + preparedFetchQuery *sql.Stmt + preparedCountQuery *sql.Stmt + preparedPruneStatement *sql.Stmt } // Close closes the underlying sqlite3 database to further changes @@ -34,6 +44,13 @@ func (s *SqliteMessageStore) Close() { s.database.Close() } +func (s *SqliteMessageStore) SetMessageCap(newcap int) { + s.countLock.Lock() + defer s.countLock.Unlock() + s.messageCap = newcap + s.checkPruneMessages() +} + // AddMessage implements the MessageStoreInterface AddMessage for sqlite message store func (s *SqliteMessageStore) AddMessage(message groups.EncryptedGroupMessage) { if s.incMessageCounterFn != nil { @@ -49,10 +66,29 @@ func (s *SqliteMessageStore) AddMessage(message groups.EncryptedGroupMessage) { log.Errorf("%v %q", stmt, err) return } + + s.countLock.Lock() + defer s.countLock.Unlock() + s.messageCount++ + s.checkPruneMessages() } -func (s SqliteMessageStore) MessagesCount() int { - rows, err := s.database.Query("SELECT COUNT(*) from messages") +func (s *SqliteMessageStore) checkPruneMessages() { + if s.messageCap != -1 && s.messageCount > s.messageCap { + log.Debugf("Message Count: %d / Message Cap: %d, message cap exceeded, pruning oldest 10%...", s.messageCount, s.messageCap) + // Delete 10% of messages + delCount := s.messageCap / 10 + stmt, err := s.preparedPruneStatement.Exec(s.messageCap / 10) + if err != nil { + log.Errorf("%v %q", stmt, err) + } + s.messageCount -= delCount + } +} + +func (s *SqliteMessageStore) MessagesCount() int { + rows, err := s.preparedCountQuery.Query() + if err != nil { log.Errorf("%v", err) return -1 @@ -75,8 +111,8 @@ func (s SqliteMessageStore) MessagesCount() int { } // FetchMessages implements the MessageStoreInterface FetchMessages for sqlite message store -func (s SqliteMessageStore) FetchMessages() []*groups.EncryptedGroupMessage { - rows, err := s.database.Query("SELECT id, signature,ciphertext from messages") +func (s *SqliteMessageStore) FetchMessages() []*groups.EncryptedGroupMessage { + rows, err := s.preparedFetchQuery.Query() if err != nil { log.Errorf("%v", err) return nil @@ -86,7 +122,7 @@ func (s SqliteMessageStore) FetchMessages() []*groups.EncryptedGroupMessage { } // FetchMessagesFrom implements the MessageStoreInterface FetchMessagesFrom for sqlite message store -func (s SqliteMessageStore) FetchMessagesFrom(signature []byte) []*groups.EncryptedGroupMessage { +func (s *SqliteMessageStore) FetchMessagesFrom(signature []byte) []*groups.EncryptedGroupMessage { // If signature is empty then treat this as a complete sync request if len(signature) == 0 { @@ -132,7 +168,7 @@ func (s *SqliteMessageStore) compileRows(rows *sql.Rows) []*groups.EncryptedGrou // InitializeSqliteMessageStore creates a database `dbfile` with the necessary tables (if it doesn't already exist) // and returns an open database -func InitializeSqliteMessageStore(dbfile string, incMessageCounterFn func()) (*SqliteMessageStore, error) { +func InitializeSqliteMessageStore(dbfile string, messageCap int, incMessageCounterFn func()) (*SqliteMessageStore, error) { db, err := sql.Open("sqlite3", dbfile) if err != nil { log.Errorf("database %v cannot be created or opened %v", dbfile, err) @@ -149,6 +185,7 @@ func InitializeSqliteMessageStore(dbfile string, incMessageCounterFn func()) (*S slms := new(SqliteMessageStore) slms.database = db slms.incMessageCounterFn = incMessageCounterFn + slms.messageCap = messageCap sqlStmt = `INSERT INTO messages(signature, ciphertext) values (?,?);` stmt, err := slms.database.Prepare(sqlStmt) @@ -158,12 +195,39 @@ func InitializeSqliteMessageStore(dbfile string, incMessageCounterFn func()) (*S } slms.preparedInsertStatement = stmt - query, err := slms.database.Prepare("SELECT id, signature,ciphertext FROM messages WHERE id>=(SELECT id FROM messages WHERE signature=(?));") + sqlStmt = "SELECT id, signature,ciphertext from messages" + query, err := slms.database.Prepare(sqlStmt) if err != nil { - log.Errorf("%v", err) + log.Errorf("%q: %s", err, sqlStmt) + return nil, fmt.Errorf("%s: %q", sqlStmt, err) + } + slms.preparedFetchQuery = query + + sqlStmt = "SELECT id, signature,ciphertext FROM messages WHERE id>=(SELECT id FROM messages WHERE signature=(?));" + query, err = slms.database.Prepare(sqlStmt) + if err != nil { + log.Errorf("%q: %s", err, sqlStmt) return nil, fmt.Errorf("%s: %q", sqlStmt, err) } slms.preparedFetchFromQuery = query + sqlStmt = "SELECT COUNT(*) from messages" + stmt, err = slms.database.Prepare(sqlStmt) + if err != nil { + log.Errorf("%q: %s", err, sqlStmt) + return nil, fmt.Errorf("%s: %q", sqlStmt, err) + } + slms.preparedCountQuery = stmt + + sqlStmt = "DELETE FROM messages WHERE id IN (SELECT id FROM messages ORDER BY id ASC LIMIT (?))" + stmt, err = slms.database.Prepare(sqlStmt) + if err != nil { + log.Errorf("%q: %s", err, sqlStmt) + return nil, fmt.Errorf("%s: %q", sqlStmt, err) + } + slms.preparedPruneStatement = stmt + + slms.messageCount = slms.MessagesCount() + return slms, nil } diff --git a/storage/message_store_test.go b/storage/message_store_test.go index 619632b..71a1a1a 100644 --- a/storage/message_store_test.go +++ b/storage/message_store_test.go @@ -16,7 +16,7 @@ func TestMessageStore(t *testing.T) { os.Remove(filename) log.SetLevel(log.LevelDebug) counter := metrics.NewCounter() - db, err := InitializeSqliteMessageStore(filename, func() { counter.Add(1) }) + db, err := InitializeSqliteMessageStore(filename, -1, func() { counter.Add(1) }) if err != nil { t.Fatalf("Error: %v", err) }