diff --git a/application/application.go b/application/application.go index c506b63..15f1740 100644 --- a/application/application.go +++ b/application/application.go @@ -44,11 +44,10 @@ func (rai *RicochetApplicationInstance) ContactRequestError() { } func (rai *RicochetApplicationInstance) SendChatMessage(message string) { - - // Technically this errors afte the second time but we can ignore it. - rai.connection.RequestOpenChannel("im.ricochet.chat", rai) - rai.connection.Do(func() error { + // Technically this errors afte the second time but we can ignore it. + rai.connection.RequestOpenChannel("im.ricochet.chat", rai) + channel := rai.connection.Channel("im.ricochet.chat", channels.Outbound) if channel != nil { chatchannel, ok := channel.Handler.(*channels.ChatChannel) diff --git a/connection/connection.go b/connection/connection.go index 48f462d..705cb83 100644 --- a/connection/connection.go +++ b/connection/connection.go @@ -1,6 +1,7 @@ package connection import ( + "context" "errors" "fmt" "github.com/golang/protobuf/proto" @@ -91,18 +92,53 @@ func (rc *Connection) start() { } } -// Do allows any function utilizing Connection to be run safetly. -// All operations which require access to Connection managed resources should -// use Do() +// Do allows any function utilizing Connection to be run safely, if you're +// careful. All operations which require access (directly or indirectly) to +// Connection while Process is running need to use Do. Calls to Do without +// Process running will block. +// +// Like a mutex, Do cannot be called recursively. This will deadlock. As +// a result, no API in this library that can be reached from the application +// should use Do, with few exceptions. This would make the API impossible +// to use safely in many cases. +// +// Do is safe to call from methods of connection.Handler and channel.Handler +// that are called by Process. func (rc *Connection) Do(do func() error) error { // Force process to soft-break so we can lock rc.traceLog("request unlocking of process loop for do()") rc.unlockChannel <- true rc.traceLog("process loop is unlocked for do()") - ret := do() - rc.traceLog("giving up lock process loop after do() ") - rc.unlockResponseChannel <- true - return ret + defer func() { + rc.traceLog("giving up lock process loop after do() ") + rc.unlockResponseChannel <- true + }() + return do() +} + +// DoContext behaves in the same way as Do, but also respects the provided +// context when blocked, and passes the context to the callback function. +// +// DoContext should be used when any call to Do may need to be cancelled +// or timed out. +func (rc *Connection) DoContext(ctx context.Context, do func(context.Context) error) error { + // Force process to soft-break so we can lock + rc.traceLog("request unlocking of process loop for do()") + select { + case rc.unlockChannel <- true: + break + case <-ctx.Done(): + rc.traceLog("giving up on unlocking process loop for do() because context cancelled") + return ctx.Err() + } + + rc.traceLog("process loop is unlocked for do()") + defer func() { + rc.traceLog("giving up lock process loop after do() ") + rc.unlockResponseChannel <- true + }() + + return do(ctx) } // RequestOpenChannel sends an OpenChannel message to the remote client. @@ -111,46 +147,61 @@ func (rc *Connection) Do(do func() error) error { // channel was opened successfully, because channels open asynchronously). func (rc *Connection) RequestOpenChannel(ctype string, handler channels.Handler) (*channels.Channel, error) { rc.traceLog(fmt.Sprintf("requesting open channel of type %s", ctype)) - var channel *channels.Channel - err := rc.Do(func() error { - // Check that we have the authentication already - if handler.RequiresAuthentication() != "none" { - // Enforce Authentication Check. - _, authed := rc.Authentication[handler.RequiresAuthentication()] - if !authed { - return utils.UnauthorizedActionError - } - } - var err error - channel, err = rc.channelManager.OpenChannelRequest(handler) + // Check that we have the authentication already + if handler.RequiresAuthentication() != "none" { + // Enforce Authentication Check. + _, authed := rc.Authentication[handler.RequiresAuthentication()] + if !authed { + return nil, utils.UnauthorizedActionError + } + } - if err != nil { - rc.traceLog(fmt.Sprintf("failed to request open channel of type %v", err)) - return err - } + channel, err := rc.channelManager.OpenChannelRequest(handler) - channel.SendMessage = func(message []byte) { - rc.SendRicochetPacket(rc.Conn, channel.ID, message) + if err != nil { + rc.traceLog(fmt.Sprintf("failed to request open channel of type %v", err)) + return nil, err + } + + channel.SendMessage = func(message []byte) { + rc.SendRicochetPacket(rc.Conn, channel.ID, message) + } + channel.DelegateAuthorization = func() { + rc.Authentication[handler.Type()] = true + } + channel.CloseChannel = func() { + rc.SendRicochetPacket(rc.Conn, channel.ID, []byte{}) + rc.channelManager.RemoveChannel(channel.ID) + } + response, err := handler.OpenOutbound(channel) + if err == nil { + rc.traceLog(fmt.Sprintf("requested open channel of type %s", ctype)) + rc.SendRicochetPacket(rc.Conn, 0, response) + } else { + rc.traceLog(fmt.Sprintf("failed to request open channel of type %v", err)) + rc.channelManager.RemoveChannel(channel.ID) + } + return channel, nil +} + +// processUserCallback should be used to wrap any calls into handlers or +// application code from the Process goroutine. It handles calls to Do +// from within that code to prevent deadlocks. +func (rc *Connection) processUserCallback(cb func()) { + done := make(chan struct{}) + go func() { + defer close(done) + cb() + }() + for { + select { + case <-done: + return + case <-rc.unlockChannel: + <-rc.unlockResponseChannel } - channel.DelegateAuthorization = func() { - rc.Authentication[handler.Type()] = true - } - channel.CloseChannel = func() { - rc.SendRicochetPacket(rc.Conn, channel.ID, []byte{}) - rc.channelManager.RemoveChannel(channel.ID) - } - response, err := handler.OpenOutbound(channel) - if err == nil { - rc.traceLog(fmt.Sprintf("requested open channel of type %s", ctype)) - rc.SendRicochetPacket(rc.Conn, 0, response) - } else { - rc.traceLog(fmt.Sprintf("failed to request open channel of type %v", err)) - rc.channelManager.RemoveChannel(channel.ID) - } - return nil - }) - return channel, err + } } // Process receives socket and protocol events for the connection. Methods @@ -164,7 +215,7 @@ func (rc *Connection) RequestOpenChannel(ctype string, handler channels.Handler) // If the connection is closed, a non-nil error is returned. func (rc *Connection) Process(handler Handler) error { rc.traceLog("entering process loop") - handler.OnReady(rc) + rc.processUserCallback(func() { handler.OnReady(rc) }) breaked := false for !breaked { @@ -181,7 +232,7 @@ func (rc *Connection) Process(handler Handler) error { break case err := <-rc.errorChannel: rc.Conn.Close() - handler.OnClosed(err) + rc.processUserCallback(func() { handler.OnClosed(err) }) return err } @@ -190,7 +241,9 @@ func (rc *Connection) Process(handler Handler) error { res := new(Protocol_Data_Control.Packet) err := proto.Unmarshal(packet.Data[:], res) if err == nil { - rc.controlPacket(handler, res) + // Wrap controlPacket in processUserCallback, since it calls out in many + // places, and wrapping the rest is harmless. + rc.processUserCallback(func() { rc.controlPacket(handler, res) }) } } else { // Let's check to see if we have defined this channel. @@ -199,11 +252,11 @@ func (rc *Connection) Process(handler Handler) error { if len(packet.Data) == 0 { rc.traceLog(fmt.Sprintf("removing channel %d", packet.Channel)) rc.channelManager.RemoveChannel(packet.Channel) - channel.Handler.Closed(utils.ChannelClosedByPeerError) + rc.processUserCallback(func() { channel.Handler.Closed(utils.ChannelClosedByPeerError) }) } else { rc.traceLog(fmt.Sprintf("received packet on %v channel %d", channel.Handler.Type(), packet.Channel)) // Send The Ricochet Packet to the Handler - channel.Handler.Packet(packet.Data[:]) + rc.processUserCallback(func() { channel.Handler.Packet(packet.Data[:]) }) } } else { // When a non-zero packet is received for an unknown diff --git a/examples/echobot/main.go b/examples/echobot/main.go index 8c4d1e7..609e3c3 100644 --- a/examples/echobot/main.go +++ b/examples/echobot/main.go @@ -64,13 +64,19 @@ func (echobot *RicochetEchoBot) Connect(privateKeyFile string, hostname string) go rc.Process(echobot) if !known { - _, err := rc.RequestOpenChannel("im.ricochet.contact.request", &channels.ContactRequestChannel{Handler: echobot}) + err := rc.Do(func() error { + _, err := rc.RequestOpenChannel("im.ricochet.contact.request", &channels.ContactRequestChannel{Handler: echobot}) + return err + }) if err != nil { log.Printf("could not contact %s", err) } } - rc.RequestOpenChannel("im.ricochet.chat", &channels.ChatChannel{Handler: echobot}) + rc.Do(func() error { + _, err := rc.RequestOpenChannel("im.ricochet.chat", &channels.ChatChannel{Handler: echobot}) + return err + }) for { message := <-echobot.messages log.Printf("Received Message: %s", message)