diff --git a/app/plugins/networkCheck.go b/app/plugins/networkCheck.go
index 3aec64b..9b38eb0 100644
--- a/app/plugins/networkCheck.go
+++ b/app/plugins/networkCheck.go
@@ -1,10 +1,9 @@
 package plugins

 import (
-	"context"
 	"cwtch.im/cwtch/event"
 	"cwtch.im/cwtch/protocol/connections"
-	"fmt"
+	"cwtch.im/cwtch/utils"
 	"git.openprivacy.ca/openprivacy/connectivity"
 	"git.openprivacy.ca/openprivacy/log"
 	"sync"
@@ -131,7 +130,7 @@ func (nc *networkCheck) checkConnection(onion string) {

 	// we want to definitively time these actions out faster than tor will, because these onions should definitely be
 	// online
-	ClientTimeout := TimeoutPolicy(time.Second * 60)
+	ClientTimeout := utils.TimeoutPolicy(time.Second * 60)
 	err := ClientTimeout.ExecuteAction(func() error {
 		conn, _, err := nc.acn.Open(onion)
 		if err == nil {
@@ -152,30 +151,3 @@ func (nc *networkCheck) checkConnection(onion string) {
 		nc.offline = false
 	}
 }
-
-// TODO we might want to reuse this, but for now it is only used by this plugin so it can live here
-
-// TimeoutPolicy is an interface for enforcing common timeout patterns
-type TimeoutPolicy time.Duration
-
-// ExecuteAction runs a function and returns an error if it hasn't returned
-// by the time specified by TimeoutPolicy
-func (tp *TimeoutPolicy) ExecuteAction(action func() error) error {
-
-	ctx, cancel := context.WithTimeout(context.Background(), time.Duration(*tp))
-	defer cancel()
-
-	// channel is buffered- this is important!
-	c := make(chan error, 1)
-	go func() {
-		// this write is non-blocking as this goroutine has sole access to the channel
-		c <- action()
-	}()
-
-	select {
-	case err := <-c:
-		return err
-	case <-ctx.Done():
-		return fmt.Errorf("ActionTimedOutError")
-	}
-}
diff --git a/app/utils/utils.go b/app/utils.go
similarity index 86%
rename from app/utils/utils.go
rename to app/utils.go
index 66c5116..799a1f3 100644
--- a/app/utils/utils.go
+++ b/app/utils.go
@@ -1,7 +1,6 @@
-package utils
+package app

 import (
-	app2 "cwtch.im/cwtch/app"
 	"cwtch.im/cwtch/model/attr"
 	"cwtch.im/cwtch/model/constants"
 	"cwtch.im/cwtch/peer"
@@ -12,7 +11,7 @@ import (
 // Proper use of an App is to call CreatePeer and then process the NewPeer event
 // however for small utility use, this function which polls the app until the peer is created
 // may fill that usecase better
-func WaitGetPeer(app app2.Application, name string) peer.CwtchPeer {
+func WaitGetPeer(app Application, name string) peer.CwtchPeer {
 	for {
 		for _, handle := range app.ListProfiles() {
 			peer := app.GetPeer(handle)
diff --git a/event/common.go b/event/common.go
index d0ab3ad..e1430f6 100644
--- a/event/common.go
+++ b/event/common.go
@@ -196,6 +196,7 @@ const (
 	MessageCounterResync = Type("MessageCounterResync")

 	// File Handling Events
+	StopFileShare        = Type("StopFileShare")
 	ShareManifest        = Type("ShareManifest")
 	ManifestSizeReceived = Type("ManifestSizeReceived")
 	ManifestError        = Type("ManifestError")
diff --git a/functionality/filesharing/filesharing_functionality.go b/functionality/filesharing/filesharing_functionality.go
index 0902ba6..e27e834 100644
--- a/functionality/filesharing/filesharing_functionality.go
+++ b/functionality/filesharing/filesharing_functionality.go
@@ -89,6 +89,61 @@ func (f *Functionality) DownloadFile(profile peer.CwtchPeer, conversationID int,
 	profile.SendScopedZonedGetValToContact(conversationID, attr.ConversationScope, attr.FilesharingZone, fmt.Sprintf("%s.manifest.size", key))
 }

+// ReShareFiles given a profile, iterates through all existing file shares and re-shares them
+// if the time limit has not expired
+func (f *Functionality) ReShareFiles(profile peer.CwtchPeer) error {
+	keys, err := profile.GetScopedZonedAttributeKeys(attr.LocalScope, attr.FilesharingZone)
+	if err != nil {
+		return err
+	}
+
+	for _, key := range keys {
+		// only look at timestamp keys
+		// this is an arbitrary choice
+		if strings.HasSuffix(key, ".ts") {
+			_, zonedpath := attr.ParseScope(key)
+			_, keypath := attr.ParseZone(zonedpath)
+			keyparts := strings.Split(keypath, ".")
+
+			// assert that the key is well-formed
+			if len(keyparts) == 3 && keyparts[2] == "ts" {
+				// fetch the timestamp key
+				filekey := strings.Join(keyparts[:2], ".")
+				timestampString, tsExists := profile.GetScopedZonedAttribute(attr.LocalScope, attr.FilesharingZone, fmt.Sprintf("%s.ts", filekey))
+
+				// assert that the timestamp actually exists
+				if !tsExists {
+					log.Errorf("could not find expected timestamp for %v", filekey)
+					continue
+				}
+
+				// assert this is an actual timestamp
+				timestamp, err := strconv.Atoi(timestampString)
+				if err != nil {
+					log.Errorf("error parsing timestamp for %v: %v", filekey, err)
+					continue
+				}
+
+				dateShared := time.Unix(int64(timestamp), 0)
+				if time.Since(dateShared) > time.Hour*24*30 {
+					log.Debugf("ignored expired file share for %v", filekey)
+					continue
+				}
+
+				// If fewer than 30 days have passed since we originally shared this file,
+				// then attempt to share this file again...
+				// TODO: In the future this would be the point to change the timestamp and reshare the file...
+				manifest, manifestExists := profile.GetScopedZonedAttribute(attr.ConversationScope, attr.FilesharingZone, fmt.Sprintf("%s.manifest", filekey))
+				if manifestExists {
+					// everything is in order, so reshare this file with the engine
+					profile.ShareFile(filekey, manifest)
+				}
+			}
+		}
+	}
+	return nil
+}
+
 // ShareFile given a profile and a conversation handle, sets up a file sharing process to share the file
 // at filepath
 func (f *Functionality) ShareFile(filepath string, profile peer.CwtchPeer) (string, string, error) {
diff --git a/peer/cwtch_peer.go b/peer/cwtch_peer.go
index 4a84f2f..b1fd470 100644
--- a/peer/cwtch_peer.go
+++ b/peer/cwtch_peer.go
@@ -185,8 +185,22 @@ func (cp *cwtchPeer) GetScopedZonedAttribute(scope attr.Scope, zone attr.Zone, k
 	return string(value), true
 }

+// GetScopedZonedAttributeKeys finds all keys associated with the given scope and zone
+func (cp *cwtchPeer) GetScopedZonedAttributeKeys(scope attr.Scope, zone attr.Zone) ([]string, error) {
+	cp.mutex.Lock()
+	defer cp.mutex.Unlock()
+	scopedZonedKey := scope.ConstructScopedZonedPath(zone.ConstructZonedPath(""))
+
+	keys, err := cp.storage.FindProfileKeysByPrefix(TypeAttribute, scopedZonedKey.ToString())
+
+	if err != nil {
+		return nil, err
+	}
+
+	return keys, nil
+}
+
 // SetScopedZonedAttribute
-// Status: Ready for 1.5
 func (cp *cwtchPeer) SetScopedZonedAttribute(scope attr.Scope, zone attr.Zone, key string, value string) {
 	cp.mutex.Lock()
 	defer cp.mutex.Unlock()
@@ -1016,18 +1030,26 @@ func (cp *cwtchPeer) storeMessage(handle string, message string, sent time.Time)
 }

 // ShareFile begins hosting the given serialized manifest
-// Status: Ready for 1.5
 func (cp *cwtchPeer) ShareFile(fileKey string, serializedManifest string) {
 	tsStr, exists := cp.GetScopedZonedAttribute(attr.LocalScope, attr.FilesharingZone, fmt.Sprintf("%s.ts", fileKey))
 	if exists {
 		ts, err := strconv.ParseInt(tsStr, 10, 64)
 		if err != nil || ts < time.Now().Unix()-2592000 {
 			log.Errorf("ignoring request to download a file offered more than 30 days ago")
+			return
 		}
 	}
+	// set the filekey status to active
+	cp.SetScopedZonedAttribute(attr.LocalScope, attr.FilesharingZone, fmt.Sprintf("%s.active", fileKey), constants.True)
 	cp.eventBus.Publish(event.NewEvent(event.ShareManifest, map[event.Field]string{event.FileKey: fileKey, event.SerializedManifest: serializedManifest}))
 }

+func (cp *cwtchPeer) StopFileShare(fileKey string) {
+	// set the filekey status to inactive
+	cp.SetScopedZonedAttribute(attr.LocalScope, attr.FilesharingZone, fmt.Sprintf("%s.active", fileKey), constants.False)
+	cp.eventBus.Publish(event.NewEvent(event.StopFileShare, map[event.Field]string{event.FileKey: fileKey}))
+}
+
 // eventHandler process events from other subsystems
 func (cp *cwtchPeer) eventHandler() {
 	for {
diff --git a/peer/cwtchprofilestorage.go b/peer/cwtchprofilestorage.go
index 45f6adb..b9512c2 100644
--- a/peer/cwtchprofilestorage.go
+++ b/peer/cwtchprofilestorage.go
@@ -39,6 +39,7 @@ type CwtchProfileStorage struct {
 	// Profile related statements
 	insertProfileKeyValueStmt *sql.Stmt
 	selectProfileKeyValueStmt *sql.Stmt
+	findProfileKeySQLStmt     *sql.Stmt

 	// Conversation related statements
 	insertConversationStmt *sql.Stmt
@@ -70,6 +71,7 @@ type ChannelID struct {

 const insertProfileKeySQLStmt = `insert or replace into profile_kv(KeyType, KeyName, KeyValue) values(?,?,?);`
 const selectProfileKeySQLStmt = `select KeyValue from profile_kv where KeyType=(?) and KeyName=(?);`
+const findProfileKeySQLStmt = `select KeyName from profile_kv where KeyType=(?) and KeyName LIKE (?);`

 const insertConversationSQLStmt = `insert into conversations(Handle, Attributes, ACL, Accepted) values(?,?,?,?);`
 const fetchAllConversationsSQLStmt = `select ID, Handle, Attributes, ACL, Accepted from conversations;`
@@ -130,6 +132,12 @@ func NewCwtchProfileStorage(db *sql.DB, profileDirectory string) (*CwtchProfileS
 		return nil, err
 	}

+	findProfileKeyStmt, err := db.Prepare(findProfileKeySQLStmt)
+	if err != nil {
+		log.Errorf("error preparing query: %v %v", findProfileKeySQLStmt, err)
+		return nil, err
+	}
+
 	insertConversationStmt, err := db.Prepare(insertConversationSQLStmt)
 	if err != nil {
 		log.Errorf("error preparing query: %v %v", insertConversationSQLStmt, err)
@@ -182,6 +190,7 @@ func NewCwtchProfileStorage(db *sql.DB, profileDirectory string) (*CwtchProfileS
 		ProfileDirectory:          profileDirectory,
 		insertProfileKeyValueStmt: insertProfileKeyValueStmt,
 		selectProfileKeyValueStmt: selectProfileKeyStmt,
+		findProfileKeySQLStmt:     findProfileKeyStmt,
 		fetchAllConversationsStmt: fetchAllConversationsStmt,
 		insertConversationStmt:    insertConversationStmt,
 		selectConversationStmt:    selectConversationStmt,
@@ -212,6 +221,34 @@ func (cps *CwtchProfileStorage) StoreProfileKeyValue(keyType StorageKeyType, key
 	return nil
 }

+// FindProfileKeysByPrefix returns all keys of the given type that begin with the given prefix
+func (cps *CwtchProfileStorage) FindProfileKeysByPrefix(keyType StorageKeyType, prefix string) ([]string, error) {
+	rows, err := cps.findProfileKeySQLStmt.Query(keyType, prefix+"%")
+	if err != nil {
+		log.Errorf("error executing query: %v", err)
+		return nil, err
+	}
+
+	var keys []string
+	defer rows.Close()
+	for {
+		result := rows.Next()
+
+		if !result {
+			return keys, nil
+		}
+
+		var key []byte
+		err = rows.Scan(&key)
+		if err != nil {
+			log.Errorf("error fetching rows: %v", err)
+			rows.Close()
+			return nil, err
+		}
+		keys = append(keys, string(key))
+	}
+}
+
 // LoadProfileKeyValue allows fetching of typed values via a known Key from the Storage Engine
 func (cps *CwtchProfileStorage) LoadProfileKeyValue(keyType StorageKeyType, key string) ([]byte, error) {
 	rows, err := cps.selectProfileKeyValueStmt.Query(keyType, key)
@@ -796,7 +833,7 @@ func (cps *CwtchProfileStorage) Export(filename string) error {
 	defer tarWriter.Close()

 	// We need to know the base directory so we can import it later (and prevent duplicates)...
-	profilePath := filepath.Base(cps.ProfileDirectory)
+	profilePath := filepath.Base(cps.ProfileDirectory)

 	err = addFileToTarWriter(profilePath, profileDB, tarWriter)
 	if err != nil {
diff --git a/peer/profile_interface.go b/peer/profile_interface.go
index 8f62aa2..e2b02eb 100644
--- a/peer/profile_interface.go
+++ b/peer/profile_interface.go
@@ -81,6 +81,9 @@ type CwtchPeer interface {
 	// scope.zone.key = value
 	GetScopedZonedAttribute(scope attr.Scope, zone attr.Zone, key string) (string, bool)

+	// GetScopedZonedAttributeKeys returns all keys associated with a given scope and zone
+	GetScopedZonedAttributeKeys(scope attr.Scope, zone attr.Zone) ([]string, error)
+
 	AccessPeeringState
 	ModifyPeeringState

@@ -114,6 +117,7 @@ type CwtchPeer interface {
 	UpdateMessageAttribute(conversation int, channel int, id int, key string, value string) error

 	ShareFile(fileKey string, serializedManifest string)
+	StopFileShare(fileKey string)
 	CheckPassword(password string) bool
 	ChangePassword(oldpassword string, newpassword string, newpasswordAgain string) error
 	Export(file string) error
diff --git a/protocol/connections/engine.go b/protocol/connections/engine.go
index 76481a9..f60a183 100644
--- a/protocol/connections/engine.go
+++ b/protocol/connections/engine.go
@@ -110,6 +110,7 @@ func NewProtocolEngine(identity primitives.Identity, privateKey ed25519.PrivateK

 	// File Handling
 	engine.eventManager.Subscribe(event.ShareManifest, engine.queue)
+	engine.eventManager.Subscribe(event.StopFileShare, engine.queue)
 	engine.eventManager.Subscribe(event.ManifestSizeReceived, engine.queue)
 	engine.eventManager.Subscribe(event.ManifestSaved, engine.queue)

@@ -217,6 +218,8 @@ func (e *engine) eventHandler() {
 			go e.listenFn()
 		case event.ShareManifest:
 			e.filesharingSubSystem.ShareFile(ev.Data[event.FileKey], ev.Data[event.SerializedManifest])
+		case event.StopFileShare:
+			e.filesharingSubSystem.StopFileShare(ev.Data[event.FileKey])
 		case event.ManifestSizeReceived:
 			handle := ev.Data[event.Handle]
 			key := ev.Data[event.FileKey]
diff --git a/protocol/files/filesharing_subsystem.go b/protocol/files/filesharing_subsystem.go
index 5f71198..d9704ff 100644
--- a/protocol/files/filesharing_subsystem.go
+++ b/protocol/files/filesharing_subsystem.go
@@ -15,7 +15,6 @@ import (
 )

 // FileSharingSubSystem encapsulates the functionality necessary to share and download files via Cwtch
-//
 type FileSharingSubSystem struct {

 	// for sharing files
@@ -38,6 +37,12 @@ func (fsss *FileSharingSubSystem) ShareFile(fileKey string, serializedManifest s
 	fsss.activeShares.Store(fileKey, &manifest)
 }

+// StopFileShare given a file key removes the serialized manifest from consideration by the file sharing
+// subsystem. Future requests on this manifest will fail, as will any in-progress chunk requests.
+func (fsss *FileSharingSubSystem) StopFileShare(fileKey string) {
+	fsss.activeShares.Delete(fileKey)
+}
+
 // FetchManifest given a file key and knowledge of the manifest size in chunks (obtained via an attribute lookup)
 // construct a request to download the manifest.
 func (fsss *FileSharingSubSystem) FetchManifest(fileKey string, manifestSize uint64) model.PeerMessage {
diff --git a/testing/cwtch_peer_server_integration_test.go b/testing/cwtch_peer_server_integration_test.go
index 43ece87..f2d39f9 100644
--- a/testing/cwtch_peer_server_integration_test.go
+++ b/testing/cwtch_peer_server_integration_test.go
@@ -4,7 +4,6 @@ import (
 	// Import SQL Cipher
 	"crypto/rand"
 	app2 "cwtch.im/cwtch/app"
-	"cwtch.im/cwtch/app/utils"
 	"cwtch.im/cwtch/event"
 	"cwtch.im/cwtch/model"
 	"cwtch.im/cwtch/model/attr"
@@ -120,17 +119,17 @@ func TestCwtchPeerIntegration(t *testing.T) {
 	log.Infoln("Creating Carol...")
 	app.CreateTaggedPeer("Carol", "asdfasdf", "test")

-	alice := utils.WaitGetPeer(app, "Alice")
+	alice := app2.WaitGetPeer(app, "Alice")
 	log.Infoln("Alice created:", alice.GetOnion())
 	alice.SetScopedZonedAttribute(attr.PublicScope, attr.ProfileZone, constants.Name, "Alice")
 	alice.AutoHandleEvents([]event.Type{event.PeerStateChange, event.ServerStateChange, event.NewGroupInvite, event.NewRetValMessageFromPeer})

-	bob := utils.WaitGetPeer(app, "Bob")
+	bob := app2.WaitGetPeer(app, "Bob")
 	log.Infoln("Bob created:", bob.GetOnion())
 	bob.SetScopedZonedAttribute(attr.PublicScope, attr.ProfileZone, constants.Name, "Bob")
 	bob.AutoHandleEvents([]event.Type{event.PeerStateChange, event.ServerStateChange, event.NewGroupInvite, event.NewRetValMessageFromPeer})

-	carol := utils.WaitGetPeer(app, "Carol")
+	carol := app2.WaitGetPeer(app, "Carol")
 	log.Infoln("Carol created:", carol.GetOnion())
 	carol.SetScopedZonedAttribute(attr.PublicScope, attr.ProfileZone, constants.Name, "Carol")
 	carol.AutoHandleEvents([]event.Type{event.PeerStateChange, event.ServerStateChange, event.NewGroupInvite, event.NewRetValMessageFromPeer})
diff --git a/testing/encryptedstorage/encrypted_storage_integration_test.go b/testing/encryptedstorage/encrypted_storage_integration_test.go
index 8d25d19..3e40d5d 100644
--- a/testing/encryptedstorage/encrypted_storage_integration_test.go
+++ b/testing/encryptedstorage/encrypted_storage_integration_test.go
@@ -4,7 +4,6 @@ import (
 	// Import SQL Cipher
 	"crypto/rand"
 	app2 "cwtch.im/cwtch/app"
-	"cwtch.im/cwtch/app/utils"
 	"cwtch.im/cwtch/model"
 	"cwtch.im/cwtch/model/constants"
 	"cwtch.im/cwtch/peer"
@@ -65,8 +64,8 @@ func TestEncryptedStorage(t *testing.T) {
 	app.CreateTaggedPeer("alice", "password", constants.ProfileTypeV1Password)
 	app.CreateTaggedPeer("bob", "password", constants.ProfileTypeV1Password)

-	alice := utils.WaitGetPeer(app, "alice")
-	bob := utils.WaitGetPeer(app, "bob")
+	alice := app2.WaitGetPeer(app, "alice")
+	bob := app2.WaitGetPeer(app, "bob")
 	alice.Listen()
 	bob.Listen()

diff --git a/testing/filesharing/file_sharing_integration_test.go b/testing/filesharing/file_sharing_integration_test.go
index c77112b..a1deebc 100644
--- a/testing/filesharing/file_sharing_integration_test.go
+++ b/testing/filesharing/file_sharing_integration_test.go
@@ -2,6 +2,7 @@ package filesharing

 import (
 	"crypto/rand"
+	utils2 "cwtch.im/cwtch/utils"
 	"encoding/base64"
 	"encoding/hex"
 	"encoding/json"
@@ -9,7 +10,6 @@ import (
 	"io/ioutil"

 	app2 "cwtch.im/cwtch/app"
-	"cwtch.im/cwtch/app/utils"
 	"cwtch.im/cwtch/event"
 	"cwtch.im/cwtch/functionality/filesharing"
 	"cwtch.im/cwtch/model"
@@ -58,7 +58,7 @@ func TestFileSharing(t *testing.T) {
 	os.RemoveAll("cwtch.out.png")
 	os.RemoveAll("cwtch.out.png.manifest")

-	log.SetLevel(log.LevelDebug)
+	log.SetLevel(log.LevelInfo)

 	os.Mkdir("tordir", 0700)
 	dataDir := path.Join("tordir", "tor")
@@ -97,15 +97,15 @@
 	os.RemoveAll(path.Join(cwtchDir, "testing"))
 	os.Mkdir(path.Join(cwtchDir, "testing"), 0700)

-	fmt.Println("Creating Alice...")
+	t.Logf("Creating Alice...")
 	app.CreateTaggedPeer("alice", "asdfasdf", "testing")

-	fmt.Println("Creating Bob...")
+	t.Logf("Creating Bob...")
 	app.CreateTaggedPeer("bob", "asdfasdf", "testing")

 	t.Logf("** Waiting for Alice, Bob...")
-	alice := utils.WaitGetPeer(app, "alice")
-	bob := utils.WaitGetPeer(app, "bob")
+	alice := app2.WaitGetPeer(app, "alice")
+	bob := app2.WaitGetPeer(app, "bob")
 	alice.AutoHandleEvents([]event.Type{event.PeerStateChange, event.NewRetValMessageFromPeer})
 	bob.AutoHandleEvents([]event.Type{event.PeerStateChange, event.NewRetValMessageFromPeer, event.ManifestReceived})

@@ -124,14 +124,14 @@ func TestFileSharing(t *testing.T) {
 	alice.NewContactConversation(bob.GetOnion(), model.DefaultP2PAccessControl(), true)
 	alice.PeerWithOnion(bob.GetOnion())

-	fmt.Println("Waiting for alice and Bob to peer...")
+	t.Logf("Waiting for alice and Bob to peer...")
 	waitForPeerPeerConnection(t, alice, bob)

-	fmt.Println("Alice and Bob are Connected!!")
+	t.Logf("Alice and Bob are Connected!!")

 	filesharingFunctionality, _ := filesharing.FunctionalityGate(map[string]bool{constants.FileSharingExperiment: true})

-	_, fileSharingMessage, err := filesharingFunctionality.ShareFile("cwtch.png", alice)
+	filekey, fileSharingMessage, err := filesharingFunctionality.ShareFile("cwtch.png", alice)

 	alice.SendMessage(1, fileSharingMessage)
 	if err != nil {
@@ -141,6 +141,44 @@ func TestFileSharing(t *testing.T) {
 	// Wait for the messages to arrive...
 	time.Sleep(time.Second * 10)

+	// test that bob can download and verify the file
+	testBobDownloadFile(t, bob, filesharingFunctionality, queueOracle)
+
+	// Test stopping and restarting file shares
+	t.Logf("Stopping File Share")
+	alice.StopFileShare(filekey)
+
+	// Allow time for the stop request to filter through Engine
+	time.Sleep(time.Second * 5)
+
+	// Restart
+	t.Logf("Restarting File Share")
+	filesharingFunctionality.ReShareFiles(alice)
+
+	// run the same download test again...to check that we can actually download the file
+	testBobDownloadFile(t, bob, filesharingFunctionality, queueOracle)
+
+	queueOracle.Shutdown()
+	app.Shutdown()
+	acn.Close()
+	time.Sleep(5 * time.Second)
+	numGoRoutinesPostACN := runtime.NumGoroutine()
+
+	// Printing out the current goroutines
+	// Very useful if we are leaking any.
+	pprof.Lookup("goroutine").WriteTo(os.Stdout, 1)
+
+	if numGoRoutinesStart != numGoRoutinesPostACN {
+		t.Errorf("Number of GoRoutines at start (%v) does not match number of goRoutines after cleanup of peers and servers (%v), clean up failed, leak detected!", numGoRoutinesStart, numGoRoutinesPostACN)
+	}
+
+}
+
+func testBobDownloadFile(t *testing.T, bob peer.CwtchPeer, filesharingFunctionality *filesharing.Functionality, queueOracle event.Queue) {
+
+	os.RemoveAll("cwtch.out.png")
+	os.RemoveAll("cwtch.out.png.manifest")
+
 	message, _, err := bob.GetChannelMessage(1, 0, 1)
 	if err != nil {
 		t.Fatalf("could not find file sharing message: %v", err)
@@ -159,28 +197,21 @@ func TestFileSharing(t *testing.T) {
 	}

 	// Wait for the file downloaded event
-	ev := queueOracle.Next()
-	if ev.EventType != event.FileDownloaded {
-		t.Fatalf("Expected file download event")
+	ClientTimeout := utils2.TimeoutPolicy(time.Second * 60)
+	err = ClientTimeout.ExecuteAction(func() error {
+		ev := queueOracle.Next()
+		if ev.EventType != event.FileDownloaded {
+			t.Fatalf("Expected file download event")
+		}
+
+		manifest, _ := files.CreateManifest("cwtch.out.png")
+		if hex.EncodeToString(manifest.RootHash) != "8f0ed73bbb30db45b6a740b1251cae02945f48e4f991464d5f3607685c45dcd136a325dab2e5f6429ce2b715e602b20b5b16bf7438fb6235fefe912adcedb5fd" {
+			t.Fatalf("file hash does not match expected %x: ", manifest.RootHash)
+		}
+		return nil
+	})
+
+	if err != nil {
+		t.Fatalf("timeout when attempting to download a file")
 	}
-
-	manifest, _ := files.CreateManifest("cwtch.out.png")
-	if hex.EncodeToString(manifest.RootHash) != "8f0ed73bbb30db45b6a740b1251cae02945f48e4f991464d5f3607685c45dcd136a325dab2e5f6429ce2b715e602b20b5b16bf7438fb6235fefe912adcedb5fd" {
-		t.Fatalf("file hash does not match expected %x: ", manifest.RootHash)
-	}
-
-	queueOracle.Shutdown()
-	app.Shutdown()
-	acn.Close()
-	time.Sleep(5 * time.Second)
-	numGoRoutinesPostACN := runtime.NumGoroutine()
-
-	// Printing out the current goroutines
-	// Very useful if we are leaking any.
-	pprof.Lookup("goroutine").WriteTo(os.Stdout, 1)
-
-	if numGoRoutinesStart != numGoRoutinesPostACN {
-		t.Errorf("Number of GoRoutines at start (%v) does not match number of goRoutines after cleanup of peers and servers (%v), clean up failed, leak detected!", numGoRoutinesStart, numGoRoutinesPostACN)
-	}
-
 }
diff --git a/app/plugins/timeout_policy_test.go b/utils/timeout_policy_test.go
similarity index 97%
rename from app/plugins/timeout_policy_test.go
rename to utils/timeout_policy_test.go
index 8f4383a..ab04646 100644
--- a/app/plugins/timeout_policy_test.go
+++ b/utils/timeout_policy_test.go
@@ -1,4 +1,4 @@
-package plugins
+package utils

 import (
 	"runtime"
diff --git a/utils/timeoutpolicy.go b/utils/timeoutpolicy.go
new file mode 100644
index 0000000..5c9dca4
--- /dev/null
+++ b/utils/timeoutpolicy.go
@@ -0,0 +1,32 @@
+package utils
+
+import (
+	"context"
+	"fmt"
+	"time"
+)
+
+// TimeoutPolicy is a type for enforcing common timeout patterns
+type TimeoutPolicy time.Duration
+
+// ExecuteAction runs a function and returns an error if it hasn't returned
+// by the time specified by TimeoutPolicy
+func (tp *TimeoutPolicy) ExecuteAction(action func() error) error {
+
+	ctx, cancel := context.WithTimeout(context.Background(), time.Duration(*tp))
+	defer cancel()
+
+	// channel is buffered - this is important!
+	c := make(chan error, 1)
+	go func() {
+		// this write is non-blocking as this goroutine has sole access to the channel
+		c <- action()
+	}()
+
+	select {
+	case err := <-c:
+		return err
+	case <-ctx.Done():
+		return fmt.Errorf("ActionTimedOutError")
+	}
+}
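
Usage sketch (illustrative only, not part of the patch): one plausible way a front end could drive the new share lifecycle, based on the APIs introduced above — ShareFile returns a file key, StopFileShare withdraws the offer, and ReShareFiles re-activates any unexpired shares after a restart. The package name, conversation id 1, and the file path "cwtch.png" are assumptions for the example, not values taken from the patch.

// share_lifecycle_example.go — a minimal sketch, assuming an already-created
// profile and an existing conversation with id 1.
package example

import (
	"cwtch.im/cwtch/functionality/filesharing"
	"cwtch.im/cwtch/model/constants"
	"cwtch.im/cwtch/peer"
)

func shareLifecycle(profile peer.CwtchPeer) error {
	// FunctionalityGate only hands back the file sharing API when the experiment is enabled.
	fshare, err := filesharing.FunctionalityGate(map[string]bool{constants.FileSharingExperiment: true})
	if err != nil {
		return err
	}

	// Offer a file; the returned file key is how this share is addressed later.
	fileKey, message, err := fshare.ShareFile("cwtch.png", profile)
	if err != nil {
		return err
	}
	profile.SendMessage(1, message) // conversation 1 is assumed to exist

	// Withdraw the offer: the engine drops the manifest, so further chunk
	// requests from peers for this file will fail.
	profile.StopFileShare(fileKey)

	// On a later start-up, re-activate every share still inside the 30-day window.
	return fshare.ReShareFiles(profile)
}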