From c24773809ed99ffff9edf96ec702ebe453f1f179 Mon Sep 17 00:00:00 2001 From: John Brooks Date: Sat, 16 Sep 2017 21:38:00 +0200 Subject: [PATCH] Fix and document safety problems with Connection.Do There were several issues with the Do function that made it nearly impossible to write safe code. First, Do cannot be called recursively -- it will deadlock. There is actually no way to implement a safe and recursive Do (or mutex) in Go, because there is no primitive that will identify the current goroutine. RequestOpenChannel used Do internally, which made it impossible to open channels safely in many circumstances. That has been removed, so all calls to RequestOpenChannel must be changed to happen under Do now. Do now has more documentation and a new rule: no code exposed through the API can use Do, unless it has sole custody of the connection (such as ProcessAuthAsClient). Related to that problem, Do was impossible to call from inside handlers (or anything else on the process goroutine) -- it would again just deadlock. This is resolved by wrapping calls into user code to continue handling invocations of Do (and only those) while the handler is executing. There is a third issue with connection close, but it will be addressed in a separate commit. And finally, because it's impossible to time out or interrupt a call to Do, I also added a DoContext method that takes a Go Context, which is also passed through to the called function. 
--- application/application.go | 7 +- connection/connection.go | 149 +++++++++++++++++++++++++------------ examples/echobot/main.go | 10 ++- 3 files changed, 112 insertions(+), 54 deletions(-) diff --git a/application/application.go b/application/application.go index c506b63..15f1740 100644 --- a/application/application.go +++ b/application/application.go @@ -44,11 +44,10 @@ func (rai *RicochetApplicationInstance) ContactRequestError() { } func (rai *RicochetApplicationInstance) SendChatMessage(message string) { - - // Technically this errors afte the second time but we can ignore it. - rai.connection.RequestOpenChannel("im.ricochet.chat", rai) - rai.connection.Do(func() error { + // Technically this errors afte the second time but we can ignore it. + rai.connection.RequestOpenChannel("im.ricochet.chat", rai) + channel := rai.connection.Channel("im.ricochet.chat", channels.Outbound) if channel != nil { chatchannel, ok := channel.Handler.(*channels.ChatChannel) diff --git a/connection/connection.go b/connection/connection.go index 48f462d..705cb83 100644 --- a/connection/connection.go +++ b/connection/connection.go @@ -1,6 +1,7 @@ package connection import ( + "context" "errors" "fmt" "github.com/golang/protobuf/proto" @@ -91,18 +92,53 @@ func (rc *Connection) start() { } } -// Do allows any function utilizing Connection to be run safetly. -// All operations which require access to Connection managed resources should -// use Do() +// Do allows any function utilizing Connection to be run safely, if you're +// careful. All operations which require access (directly or indirectly) to +// Connection while Process is running need to use Do. Calls to Do without +// Process running will block. +// +// Like a mutex, Do cannot be called recursively. This will deadlock. As +// a result, no API in this library that can be reached from the application +// should use Do, with few exceptions. This would make the API impossible +// to use safely in many cases. 
+// +// Do is safe to call from methods of connection.Handler and channel.Handler +// that are called by Process. func (rc *Connection) Do(do func() error) error { // Force process to soft-break so we can lock rc.traceLog("request unlocking of process loop for do()") rc.unlockChannel <- true rc.traceLog("process loop is unlocked for do()") - ret := do() - rc.traceLog("giving up lock process loop after do() ") - rc.unlockResponseChannel <- true - return ret + defer func() { + rc.traceLog("giving up lock process loop after do() ") + rc.unlockResponseChannel <- true + }() + return do() +} + +// DoContext behaves in the same way as Do, but also respects the provided +// context when blocked, and passes the context to the callback function. +// +// DoContext should be used when any call to Do may need to be cancelled +// or timed out. +func (rc *Connection) DoContext(ctx context.Context, do func(context.Context) error) error { + // Force process to soft-break so we can lock + rc.traceLog("request unlocking of process loop for do()") + select { + case rc.unlockChannel <- true: + break + case <-ctx.Done(): + rc.traceLog("giving up on unlocking process loop for do() because context cancelled") + return ctx.Err() + } + + rc.traceLog("process loop is unlocked for do()") + defer func() { + rc.traceLog("giving up lock process loop after do() ") + rc.unlockResponseChannel <- true + }() + + return do(ctx) } // RequestOpenChannel sends an OpenChannel message to the remote client. @@ -111,46 +147,61 @@ func (rc *Connection) Do(do func() error) error { // channel was opened successfully, because channels open asynchronously). 
func (rc *Connection) RequestOpenChannel(ctype string, handler channels.Handler) (*channels.Channel, error) { rc.traceLog(fmt.Sprintf("requesting open channel of type %s", ctype)) - var channel *channels.Channel - err := rc.Do(func() error { - // Check that we have the authentication already - if handler.RequiresAuthentication() != "none" { - // Enforce Authentication Check. - _, authed := rc.Authentication[handler.RequiresAuthentication()] - if !authed { - return utils.UnauthorizedActionError - } - } - var err error - channel, err = rc.channelManager.OpenChannelRequest(handler) + // Check that we have the authentication already + if handler.RequiresAuthentication() != "none" { + // Enforce Authentication Check. + _, authed := rc.Authentication[handler.RequiresAuthentication()] + if !authed { + return nil, utils.UnauthorizedActionError + } + } - if err != nil { - rc.traceLog(fmt.Sprintf("failed to request open channel of type %v", err)) - return err - } + channel, err := rc.channelManager.OpenChannelRequest(handler) - channel.SendMessage = func(message []byte) { - rc.SendRicochetPacket(rc.Conn, channel.ID, message) + if err != nil { + rc.traceLog(fmt.Sprintf("failed to request open channel of type %v", err)) + return nil, err + } + + channel.SendMessage = func(message []byte) { + rc.SendRicochetPacket(rc.Conn, channel.ID, message) + } + channel.DelegateAuthorization = func() { + rc.Authentication[handler.Type()] = true + } + channel.CloseChannel = func() { + rc.SendRicochetPacket(rc.Conn, channel.ID, []byte{}) + rc.channelManager.RemoveChannel(channel.ID) + } + response, err := handler.OpenOutbound(channel) + if err == nil { + rc.traceLog(fmt.Sprintf("requested open channel of type %s", ctype)) + rc.SendRicochetPacket(rc.Conn, 0, response) + } else { + rc.traceLog(fmt.Sprintf("failed to request open channel of type %v", err)) + rc.channelManager.RemoveChannel(channel.ID) + } + return channel, nil +} + +// processUserCallback should be used to wrap any calls into 
handlers or +// application code from the Process goroutine. It handles calls to Do +// from within that code to prevent deadlocks. +func (rc *Connection) processUserCallback(cb func()) { + done := make(chan struct{}) + go func() { + defer close(done) + cb() + }() + for { + select { + case <-done: + return + case <-rc.unlockChannel: + <-rc.unlockResponseChannel } - channel.DelegateAuthorization = func() { - rc.Authentication[handler.Type()] = true - } - channel.CloseChannel = func() { - rc.SendRicochetPacket(rc.Conn, channel.ID, []byte{}) - rc.channelManager.RemoveChannel(channel.ID) - } - response, err := handler.OpenOutbound(channel) - if err == nil { - rc.traceLog(fmt.Sprintf("requested open channel of type %s", ctype)) - rc.SendRicochetPacket(rc.Conn, 0, response) - } else { - rc.traceLog(fmt.Sprintf("failed to request open channel of type %v", err)) - rc.channelManager.RemoveChannel(channel.ID) - } - return nil - }) - return channel, err + } } // Process receives socket and protocol events for the connection. Methods @@ -164,7 +215,7 @@ func (rc *Connection) RequestOpenChannel(ctype string, handler channels.Handler) // If the connection is closed, a non-nil error is returned. func (rc *Connection) Process(handler Handler) error { rc.traceLog("entering process loop") - handler.OnReady(rc) + rc.processUserCallback(func() { handler.OnReady(rc) }) breaked := false for !breaked { @@ -181,7 +232,7 @@ func (rc *Connection) Process(handler Handler) error { break case err := <-rc.errorChannel: rc.Conn.Close() - handler.OnClosed(err) + rc.processUserCallback(func() { handler.OnClosed(err) }) return err } @@ -190,7 +241,9 @@ func (rc *Connection) Process(handler Handler) error { res := new(Protocol_Data_Control.Packet) err := proto.Unmarshal(packet.Data[:], res) if err == nil { - rc.controlPacket(handler, res) + // Wrap controlPacket in processUserCallback, since it calls out in many + // places, and wrapping the rest is harmless. 
+ rc.processUserCallback(func() { rc.controlPacket(handler, res) }) } } else { // Let's check to see if we have defined this channel. @@ -199,11 +252,11 @@ func (rc *Connection) Process(handler Handler) error { if len(packet.Data) == 0 { rc.traceLog(fmt.Sprintf("removing channel %d", packet.Channel)) rc.channelManager.RemoveChannel(packet.Channel) - channel.Handler.Closed(utils.ChannelClosedByPeerError) + rc.processUserCallback(func() { channel.Handler.Closed(utils.ChannelClosedByPeerError) }) } else { rc.traceLog(fmt.Sprintf("received packet on %v channel %d", channel.Handler.Type(), packet.Channel)) // Send The Ricochet Packet to the Handler - channel.Handler.Packet(packet.Data[:]) + rc.processUserCallback(func() { channel.Handler.Packet(packet.Data[:]) }) } } else { // When a non-zero packet is received for an unknown diff --git a/examples/echobot/main.go b/examples/echobot/main.go index 8c4d1e7..609e3c3 100644 --- a/examples/echobot/main.go +++ b/examples/echobot/main.go @@ -64,13 +64,19 @@ func (echobot *RicochetEchoBot) Connect(privateKeyFile string, hostname string) go rc.Process(echobot) if !known { - _, err := rc.RequestOpenChannel("im.ricochet.contact.request", &channels.ContactRequestChannel{Handler: echobot}) + err := rc.Do(func() error { + _, err := rc.RequestOpenChannel("im.ricochet.contact.request", &channels.ContactRequestChannel{Handler: echobot}) + return err + }) if err != nil { log.Printf("could not contact %s", err) } } - rc.RequestOpenChannel("im.ricochet.chat", &channels.ChatChannel{Handler: echobot}) + rc.Do(func() error { + _, err := rc.RequestOpenChannel("im.ricochet.chat", &channels.ChatChannel{Handler: echobot}) + return err + }) for { message := <-echobot.messages log.Printf("Received Message: %s", message)