Fix and document safety problems with Connection.Do

There were several issues with the Do function that made it nearly
impossible to write safe code.

First, Do cannot be called recursively -- it will deadlock. There is
actually no way to implement a safe and recursive Do (or mutex) in Go,
because there is no primitive that will identify the current goroutine.

RequestOpenChannel used Do internally, which made it impossible to open
channels safely in many circumstances. That has been removed, so all
calls to RequestOpenChannel must be changed to happen under Do now.

Do now has more documentation and a new rule: no code exposed through
API can use Do, unless it has sole custody of the connection (such as
ProcessAuthAsClient).

Related to that problem, Do was impossible to call from inside handlers
(or anything else on the process goroutine) -- it would again just
deadlock. This is resolved by wrapping calls into user code to continue
handling invocations of Do (and only those) while the handler is
executing.

There is a third issue with connection close, but it will be addressed
in a separate commit

And finally, because it's impossible to timeout or interrupt a call to
Do, I also added a DoContext method that takes a go Context, which is
also passed through to the called function.
This commit is contained in:
John Brooks 2017-09-16 21:38:00 +02:00 committed by Sarah Jamie Lewis
parent 0f47f62465
commit c24773809e
3 changed files with 112 additions and 54 deletions

View File

@ -44,11 +44,10 @@ func (rai *RicochetApplicationInstance) ContactRequestError() {
}
func (rai *RicochetApplicationInstance) SendChatMessage(message string) {
// Technically this errors afte the second time but we can ignore it.
rai.connection.RequestOpenChannel("im.ricochet.chat", rai)
rai.connection.Do(func() error {
// Technically this errors afte the second time but we can ignore it.
rai.connection.RequestOpenChannel("im.ricochet.chat", rai)
channel := rai.connection.Channel("im.ricochet.chat", channels.Outbound)
if channel != nil {
chatchannel, ok := channel.Handler.(*channels.ChatChannel)

View File

@ -1,6 +1,7 @@
package connection
import (
"context"
"errors"
"fmt"
"github.com/golang/protobuf/proto"
@ -91,18 +92,53 @@ func (rc *Connection) start() {
}
}
// Do allows any function utilizing Connection to be run safetly.
// All operations which require access to Connection managed resources should
// use Do()
// Do allows any function utilizing Connection to be run safely, if you're
// careful. All operations which require access (directly or indirectly) to
// Connection while Process is running need to use Do. Calls to Do without
// Process running will block.
//
// Like a mutex, Do cannot be called recursively. This will deadlock. As
// a result, no API in this library that can be reached from the application
// should use Do, with few exceptions. This would make the API impossible
// to use safely in many cases.
//
// Do is safe to call from methods of connection.Handler and channel.Handler
// that are called by Process.
func (rc *Connection) Do(do func() error) error {
// Force process to soft-break so we can lock
rc.traceLog("request unlocking of process loop for do()")
rc.unlockChannel <- true
rc.traceLog("process loop is unlocked for do()")
ret := do()
rc.traceLog("giving up lock process loop after do() ")
rc.unlockResponseChannel <- true
return ret
defer func() {
rc.traceLog("giving up lock process loop after do() ")
rc.unlockResponseChannel <- true
}()
return do()
}
// DoContext behaves in the same way as Do, but also respects the provided
// context when blocked, and passes the context to the callback function.
//
// DoContext should be used when any call to Do may need to be cancelled
// or timed out.
func (rc *Connection) DoContext(ctx context.Context, do func(context.Context) error) error {
// Force process to soft-break so we can lock
rc.traceLog("request unlocking of process loop for do()")
select {
case rc.unlockChannel <- true:
break
case <-ctx.Done():
rc.traceLog("giving up on unlocking process loop for do() because context cancelled")
return ctx.Err()
}
rc.traceLog("process loop is unlocked for do()")
defer func() {
rc.traceLog("giving up lock process loop after do() ")
rc.unlockResponseChannel <- true
}()
return do(ctx)
}
// RequestOpenChannel sends an OpenChannel message to the remote client.
@ -111,46 +147,61 @@ func (rc *Connection) Do(do func() error) error {
// channel was opened successfully, because channels open asynchronously).
func (rc *Connection) RequestOpenChannel(ctype string, handler channels.Handler) (*channels.Channel, error) {
rc.traceLog(fmt.Sprintf("requesting open channel of type %s", ctype))
var channel *channels.Channel
err := rc.Do(func() error {
// Check that we have the authentication already
if handler.RequiresAuthentication() != "none" {
// Enforce Authentication Check.
_, authed := rc.Authentication[handler.RequiresAuthentication()]
if !authed {
return utils.UnauthorizedActionError
}
}
var err error
channel, err = rc.channelManager.OpenChannelRequest(handler)
// Check that we have the authentication already
if handler.RequiresAuthentication() != "none" {
// Enforce Authentication Check.
_, authed := rc.Authentication[handler.RequiresAuthentication()]
if !authed {
return nil, utils.UnauthorizedActionError
}
}
if err != nil {
rc.traceLog(fmt.Sprintf("failed to request open channel of type %v", err))
return err
}
channel, err := rc.channelManager.OpenChannelRequest(handler)
channel.SendMessage = func(message []byte) {
rc.SendRicochetPacket(rc.Conn, channel.ID, message)
if err != nil {
rc.traceLog(fmt.Sprintf("failed to request open channel of type %v", err))
return nil, err
}
channel.SendMessage = func(message []byte) {
rc.SendRicochetPacket(rc.Conn, channel.ID, message)
}
channel.DelegateAuthorization = func() {
rc.Authentication[handler.Type()] = true
}
channel.CloseChannel = func() {
rc.SendRicochetPacket(rc.Conn, channel.ID, []byte{})
rc.channelManager.RemoveChannel(channel.ID)
}
response, err := handler.OpenOutbound(channel)
if err == nil {
rc.traceLog(fmt.Sprintf("requested open channel of type %s", ctype))
rc.SendRicochetPacket(rc.Conn, 0, response)
} else {
rc.traceLog(fmt.Sprintf("failed to request open channel of type %v", err))
rc.channelManager.RemoveChannel(channel.ID)
}
return channel, nil
}
// processUserCallback should be used to wrap any calls into handlers or
// application code from the Process goroutine. It handles calls to Do
// from within that code to prevent deadlocks.
func (rc *Connection) processUserCallback(cb func()) {
done := make(chan struct{})
go func() {
defer close(done)
cb()
}()
for {
select {
case <-done:
return
case <-rc.unlockChannel:
<-rc.unlockResponseChannel
}
channel.DelegateAuthorization = func() {
rc.Authentication[handler.Type()] = true
}
channel.CloseChannel = func() {
rc.SendRicochetPacket(rc.Conn, channel.ID, []byte{})
rc.channelManager.RemoveChannel(channel.ID)
}
response, err := handler.OpenOutbound(channel)
if err == nil {
rc.traceLog(fmt.Sprintf("requested open channel of type %s", ctype))
rc.SendRicochetPacket(rc.Conn, 0, response)
} else {
rc.traceLog(fmt.Sprintf("failed to request open channel of type %v", err))
rc.channelManager.RemoveChannel(channel.ID)
}
return nil
})
return channel, err
}
}
// Process receives socket and protocol events for the connection. Methods
@ -164,7 +215,7 @@ func (rc *Connection) RequestOpenChannel(ctype string, handler channels.Handler)
// If the connection is closed, a non-nil error is returned.
func (rc *Connection) Process(handler Handler) error {
rc.traceLog("entering process loop")
handler.OnReady(rc)
rc.processUserCallback(func() { handler.OnReady(rc) })
breaked := false
for !breaked {
@ -181,7 +232,7 @@ func (rc *Connection) Process(handler Handler) error {
break
case err := <-rc.errorChannel:
rc.Conn.Close()
handler.OnClosed(err)
rc.processUserCallback(func() { handler.OnClosed(err) })
return err
}
@ -190,7 +241,9 @@ func (rc *Connection) Process(handler Handler) error {
res := new(Protocol_Data_Control.Packet)
err := proto.Unmarshal(packet.Data[:], res)
if err == nil {
rc.controlPacket(handler, res)
// Wrap controlPacket in processUserCallback, since it calls out in many
// places, and wrapping the rest is harmless.
rc.processUserCallback(func() { rc.controlPacket(handler, res) })
}
} else {
// Let's check to see if we have defined this channel.
@ -199,11 +252,11 @@ func (rc *Connection) Process(handler Handler) error {
if len(packet.Data) == 0 {
rc.traceLog(fmt.Sprintf("removing channel %d", packet.Channel))
rc.channelManager.RemoveChannel(packet.Channel)
channel.Handler.Closed(utils.ChannelClosedByPeerError)
rc.processUserCallback(func() { channel.Handler.Closed(utils.ChannelClosedByPeerError) })
} else {
rc.traceLog(fmt.Sprintf("received packet on %v channel %d", channel.Handler.Type(), packet.Channel))
// Send The Ricochet Packet to the Handler
channel.Handler.Packet(packet.Data[:])
rc.processUserCallback(func() { channel.Handler.Packet(packet.Data[:]) })
}
} else {
// When a non-zero packet is received for an unknown

View File

@ -64,13 +64,19 @@ func (echobot *RicochetEchoBot) Connect(privateKeyFile string, hostname string)
go rc.Process(echobot)
if !known {
_, err := rc.RequestOpenChannel("im.ricochet.contact.request", &channels.ContactRequestChannel{Handler: echobot})
err := rc.Do(func() error {
_, err := rc.RequestOpenChannel("im.ricochet.contact.request", &channels.ContactRequestChannel{Handler: echobot})
return err
})
if err != nil {
log.Printf("could not contact %s", err)
}
}
rc.RequestOpenChannel("im.ricochet.chat", &channels.ChatChannel{Handler: echobot})
rc.Do(func() error {
_, err := rc.RequestOpenChannel("im.ricochet.chat", &channels.ChatChannel{Handler: echobot})
return err
})
for {
message := <-echobot.messages
log.Printf("Received Message: %s", message)