diff --git a/connection/connection.go b/connection/connection.go index 705cb83..f2d1460 100644 --- a/connection/connection.go +++ b/connection/connection.go @@ -10,6 +10,7 @@ import ( "github.com/s-rah/go-ricochet/wire/control" "io" "log" + "sync" ) // Connection encapsulates the state required to maintain a connection to @@ -24,7 +25,7 @@ type Connection struct { errorChannel chan error breakChannel chan bool - breakResultChannel chan bool + breakResultChannel chan error unlockChannel chan bool unlockResponseChannel chan bool @@ -32,6 +33,13 @@ type Connection struct { messageBuilder utils.MessageBuilder trace bool + closed bool + closing bool + // This mutex is exclusively for preventing races during blocking + // interactions with Process; specifically Do and Break. Don't use + // it for anything else. See those functions for an explanation. + processBlockMutex sync.Mutex + Conn io.ReadWriteCloser IsInbound bool Authentication map[string]bool @@ -44,7 +52,7 @@ func (rc *Connection) init() { rc.errorChannel = make(chan error) rc.breakChannel = make(chan bool) - rc.breakResultChannel = make(chan bool) + rc.breakResultChannel = make(chan error) rc.unlockChannel = make(chan bool) rc.unlockResponseChannel = make(chan bool) @@ -95,7 +103,8 @@ func (rc *Connection) start() { // Do allows any function utilizing Connection to be run safely, if you're // careful. All operations which require access (directly or indirectly) to // Connection while Process is running need to use Do. Calls to Do without -// Process running will block. +// Process running will block unless the connection is closed, which is +// returned as ConnectionClosedError. // // Like a mutex, Do cannot be called recursively. This will deadlock. As // a result, no API in this library that can be reached from the application @@ -105,6 +114,30 @@ func (rc *Connection) start() { // Do is safe to call from methods of connection.Handler and channel.Handler // that are called by Process. func (rc *Connection) Do(do func() error) error { + // There's a complicated little dance here to prevent a race when the + // Process call is returning for a connection error. The problem is + // that if Do simply checked rc.closed and then tried to send, it's + // possible for Process to change rc.closed and stop reading before the + // send statement is executed, creating a deadlock. + // + // To prevent this, all of the functions that block on Process should + // do so by acquiring processBlockMutex, aborting if rc.closed is true, + // performing their blocking channel operations, and then releasing the + // mutex. + // + // This works because Process will always use a separate goroutine to + // acquire processBlockMutex before changing rc.closed, and the mutex + // guarantees that no blocking channel operation can happen during or + // after the value is changed. Since these operations block the Process + // loop, the behavior of multiple concurrent calls to Do/Break doesn't + // change: they just end up blocking on the mutex before blocking on the + // channel. + rc.processBlockMutex.Lock() + defer rc.processBlockMutex.Unlock() + if rc.closed { + return utils.ConnectionClosedError + } + // Force process to soft-break so we can lock rc.traceLog("request unlocking of process loop for do()") rc.unlockChannel <- true @@ -113,6 +146,13 @@ func (rc *Connection) Do(do func() error) error { rc.traceLog("giving up lock process loop after do() ") rc.unlockResponseChannel <- true }() + + // Process sets rc.closing when it's trying to acquire the mutex and + // close down the connection. Behave as if the connection was already + // closed. + if rc.closing { + return utils.ConnectionClosedError + } return do() } @@ -122,6 +162,13 @@ func (rc *Connection) Do(do func() error) error { // DoContext should be used when any call to Do may need to be cancelled // or timed out. func (rc *Connection) DoContext(ctx context.Context, do func(context.Context) error) error { + // .. see above + rc.processBlockMutex.Lock() + defer rc.processBlockMutex.Unlock() + if rc.closed { + return utils.ConnectionClosedError + } + // Force process to soft-break so we can lock rc.traceLog("request unlocking of process loop for do()") select { @@ -138,6 +185,9 @@ func (rc *Connection) DoContext(ctx context.Context, do func(context.Context) er rc.unlockResponseChannel <- true }() + if rc.closing { + return utils.ConnectionClosedError + } return do(ctx) } @@ -214,10 +264,22 @@ func (rc *Connection) processUserCallback(cb func()) { // Process blocks until the connection is closed or until Break() is called. // If the connection is closed, a non-nil error is returned. func (rc *Connection) Process(handler Handler) error { + if rc.closed { + return utils.ConnectionClosedError + } rc.traceLog("entering process loop") rc.processUserCallback(func() { handler.OnReady(rc) }) - breaked := false - for !breaked { + + // There are exactly two ways out of this loop: a signal on breakChannel + // caused by a call to Break, or a connection-fatal error on errorChannel. + // + // In the Break case, no particular care is necessary; it is the caller's + // responsibility to make sure there aren't e.g. concurrent calls to Do. + // + // Because connection errors can happen spontaneously, they must carefully + // prevent concurrent calls to Break or Do that could deadlock when Process + // returns. + for { var packet utils.RicochetData select { @@ -226,13 +288,46 @@ func (rc *Connection) Process(handler Handler) error { continue case <-rc.breakChannel: rc.traceLog("process has ended after break") - breaked = true - continue + rc.breakResultChannel <- nil + return nil case packet = <-rc.packetChannel: break case err := <-rc.errorChannel: rc.Conn.Close() - rc.processUserCallback(func() { handler.OnClosed(err) }) + rc.closing = true + + // In order to safely close down concurrent calls to Do or Break, + // processBlockMutex must be held before setting rc.closed. That cannot + // happen in this goroutine, because one of those calls may already hold + // the mutex and be blocking on a channel send to this method. So the + // process here is to have a goroutine acquire the lock, set rc.closed, and + // signal back. Meanwhile, this one keeps handling unlockChannel and + // breakChannel. + closedChan := make(chan struct{}) + go func() { + rc.processBlockMutex.Lock() + defer rc.processBlockMutex.Unlock() + rc.closed = true + close(closedChan) + }() + + // Keep accepting calls from Do or Break until closedChan signals that they're + // safely shut down. + clearLoop: + for { + select { + case <-rc.unlockChannel: + <-rc.unlockResponseChannel + case <-rc.breakChannel: + rc.breakResultChannel <- utils.ConnectionClosedError + case <-closedChan: + break clearLoop + } + } + + // This is the one case where processUserCallback isn't necessary, because + // all calls to Do immediately return ConnectionClosedError now. + handler.OnClosed(err) return err } @@ -269,10 +364,6 @@ func (rc *Connection) Process(handler Handler) error { } } } - - rc.breakResultChannel <- true - return nil - } func (rc *Connection) controlPacket(handler Handler, res *Protocol_Data_Control.Packet) { @@ -383,10 +474,21 @@ func (rc *Connection) traceLog(message string) { } // Break causes Process() to return, but does not close the underlying connection -func (rc *Connection) Break() { +// Break returns an error if it would not be valid to call Process() again for +// the connection now. Currently, the only such error is ConnectionClosedError. +func (rc *Connection) Break() error { + // See Do() for an explanation of the concurrency here; it's complicated. + // The summary is that this mutex prevents races on connection close that + // could lead to deadlocks in Block(). + rc.processBlockMutex.Lock() + defer rc.processBlockMutex.Unlock() + if rc.closed { + rc.traceLog("ignoring break because connection is already closed") + return utils.ConnectionClosedError + } rc.traceLog("breaking out of process loop") rc.breakChannel <- true - <-rc.breakResultChannel // Wait for Process to End + return <-rc.breakResultChannel // Wait for Process to End } // Channel is a convienciance method for returning a given channel to the caller diff --git a/utils/error.go b/utils/error.go index 7cf28cf..d65fb0d 100644 --- a/utils/error.go +++ b/utils/error.go @@ -39,6 +39,9 @@ const ( // Library Use Errors PrivateKeyNotSetError = Error("ClientFailedToAuthenticateError") + + // Connection Errors + ConnectionClosedError = Error("ConnectionClosedError") ) // CheckError is a helper function for panicing on errors which we need to handle