Fix and document safety problems with Connection.Do

There were several issues with the Do function that made it nearly
impossible to write safe code.

First, Do cannot be called recursively -- it will deadlock. There is
actually no way to implement a safe and recursive Do (or mutex) in Go,
because there is no primitive that will identify the current goroutine.

RequestOpenChannel used Do internally, which made it impossible to open
channels safely in many circumstances. That has been removed, so all
calls to RequestOpenChannel must be changed to happen under Do now.

Do now has more documentation and a new rule: no code exposed through
API can use Do, unless it has sole custody of the connection (such as
ProcessAuthAsClient).

Related to that problem, Do was impossible to call from inside handlers
(or anything else on the process goroutine) -- it would again just
deadlock. This is resolved by wrapping calls into user code to continue
handling invocations of Do (and only those) while the handler is
executing.

There is a third issue with connection close, but it will be addressed
in a separate commit

And finally, because it's impossible to timeout or interrupt a call to
Do, I also added a DoContext method that takes a go Context, which is
also passed through to the called function.
This commit is contained in:
John Brooks 2017-09-16 21:38:00 +02:00 committed by Sarah Jamie Lewis
parent 0f47f62465
commit c24773809e
3 changed files with 112 additions and 54 deletions

View File

@ -44,11 +44,10 @@ func (rai *RicochetApplicationInstance) ContactRequestError() {
} }
func (rai *RicochetApplicationInstance) SendChatMessage(message string) { func (rai *RicochetApplicationInstance) SendChatMessage(message string) {
// Technically this errors afte the second time but we can ignore it.
rai.connection.RequestOpenChannel("im.ricochet.chat", rai)
rai.connection.Do(func() error { rai.connection.Do(func() error {
// Technically this errors afte the second time but we can ignore it.
rai.connection.RequestOpenChannel("im.ricochet.chat", rai)
channel := rai.connection.Channel("im.ricochet.chat", channels.Outbound) channel := rai.connection.Channel("im.ricochet.chat", channels.Outbound)
if channel != nil { if channel != nil {
chatchannel, ok := channel.Handler.(*channels.ChatChannel) chatchannel, ok := channel.Handler.(*channels.ChatChannel)

View File

@ -1,6 +1,7 @@
package connection package connection
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"github.com/golang/protobuf/proto" "github.com/golang/protobuf/proto"
@ -91,18 +92,53 @@ func (rc *Connection) start() {
} }
} }
// Do allows any function utilizing Connection to be run safetly. // Do allows any function utilizing Connection to be run safely, if you're
// All operations which require access to Connection managed resources should // careful. All operations which require access (directly or indirectly) to
// use Do() // Connection while Process is running need to use Do. Calls to Do without
// Process running will block.
//
// Like a mutex, Do cannot be called recursively. This will deadlock. As
// a result, no API in this library that can be reached from the application
// should use Do, with few exceptions. This would make the API impossible
// to use safely in many cases.
//
// Do is safe to call from methods of connection.Handler and channel.Handler
// that are called by Process.
func (rc *Connection) Do(do func() error) error { func (rc *Connection) Do(do func() error) error {
// Force process to soft-break so we can lock // Force process to soft-break so we can lock
rc.traceLog("request unlocking of process loop for do()") rc.traceLog("request unlocking of process loop for do()")
rc.unlockChannel <- true rc.unlockChannel <- true
rc.traceLog("process loop is unlocked for do()") rc.traceLog("process loop is unlocked for do()")
ret := do() defer func() {
rc.traceLog("giving up lock process loop after do() ") rc.traceLog("giving up lock process loop after do() ")
rc.unlockResponseChannel <- true rc.unlockResponseChannel <- true
return ret }()
return do()
}
// DoContext behaves in the same way as Do, but also respects the provided
// context when blocked, and passes the context to the callback function.
//
// DoContext should be used when any call to Do may need to be cancelled
// or timed out.
func (rc *Connection) DoContext(ctx context.Context, do func(context.Context) error) error {
// Force process to soft-break so we can lock
rc.traceLog("request unlocking of process loop for do()")
select {
case rc.unlockChannel <- true:
break
case <-ctx.Done():
rc.traceLog("giving up on unlocking process loop for do() because context cancelled")
return ctx.Err()
}
rc.traceLog("process loop is unlocked for do()")
defer func() {
rc.traceLog("giving up lock process loop after do() ")
rc.unlockResponseChannel <- true
}()
return do(ctx)
} }
// RequestOpenChannel sends an OpenChannel message to the remote client. // RequestOpenChannel sends an OpenChannel message to the remote client.
@ -111,46 +147,61 @@ func (rc *Connection) Do(do func() error) error {
// channel was opened successfully, because channels open asynchronously). // channel was opened successfully, because channels open asynchronously).
func (rc *Connection) RequestOpenChannel(ctype string, handler channels.Handler) (*channels.Channel, error) { func (rc *Connection) RequestOpenChannel(ctype string, handler channels.Handler) (*channels.Channel, error) {
rc.traceLog(fmt.Sprintf("requesting open channel of type %s", ctype)) rc.traceLog(fmt.Sprintf("requesting open channel of type %s", ctype))
var channel *channels.Channel
err := rc.Do(func() error {
// Check that we have the authentication already
if handler.RequiresAuthentication() != "none" {
// Enforce Authentication Check.
_, authed := rc.Authentication[handler.RequiresAuthentication()]
if !authed {
return utils.UnauthorizedActionError
}
}
var err error // Check that we have the authentication already
channel, err = rc.channelManager.OpenChannelRequest(handler) if handler.RequiresAuthentication() != "none" {
// Enforce Authentication Check.
_, authed := rc.Authentication[handler.RequiresAuthentication()]
if !authed {
return nil, utils.UnauthorizedActionError
}
}
if err != nil { channel, err := rc.channelManager.OpenChannelRequest(handler)
rc.traceLog(fmt.Sprintf("failed to request open channel of type %v", err))
return err
}
channel.SendMessage = func(message []byte) { if err != nil {
rc.SendRicochetPacket(rc.Conn, channel.ID, message) rc.traceLog(fmt.Sprintf("failed to request open channel of type %v", err))
return nil, err
}
channel.SendMessage = func(message []byte) {
rc.SendRicochetPacket(rc.Conn, channel.ID, message)
}
channel.DelegateAuthorization = func() {
rc.Authentication[handler.Type()] = true
}
channel.CloseChannel = func() {
rc.SendRicochetPacket(rc.Conn, channel.ID, []byte{})
rc.channelManager.RemoveChannel(channel.ID)
}
response, err := handler.OpenOutbound(channel)
if err == nil {
rc.traceLog(fmt.Sprintf("requested open channel of type %s", ctype))
rc.SendRicochetPacket(rc.Conn, 0, response)
} else {
rc.traceLog(fmt.Sprintf("failed to request open channel of type %v", err))
rc.channelManager.RemoveChannel(channel.ID)
}
return channel, nil
}
// processUserCallback should be used to wrap any calls into handlers or
// application code from the Process goroutine. It handles calls to Do
// from within that code to prevent deadlocks.
func (rc *Connection) processUserCallback(cb func()) {
done := make(chan struct{})
go func() {
defer close(done)
cb()
}()
for {
select {
case <-done:
return
case <-rc.unlockChannel:
<-rc.unlockResponseChannel
} }
channel.DelegateAuthorization = func() { }
rc.Authentication[handler.Type()] = true
}
channel.CloseChannel = func() {
rc.SendRicochetPacket(rc.Conn, channel.ID, []byte{})
rc.channelManager.RemoveChannel(channel.ID)
}
response, err := handler.OpenOutbound(channel)
if err == nil {
rc.traceLog(fmt.Sprintf("requested open channel of type %s", ctype))
rc.SendRicochetPacket(rc.Conn, 0, response)
} else {
rc.traceLog(fmt.Sprintf("failed to request open channel of type %v", err))
rc.channelManager.RemoveChannel(channel.ID)
}
return nil
})
return channel, err
} }
// Process receives socket and protocol events for the connection. Methods // Process receives socket and protocol events for the connection. Methods
@ -164,7 +215,7 @@ func (rc *Connection) RequestOpenChannel(ctype string, handler channels.Handler)
// If the connection is closed, a non-nil error is returned. // If the connection is closed, a non-nil error is returned.
func (rc *Connection) Process(handler Handler) error { func (rc *Connection) Process(handler Handler) error {
rc.traceLog("entering process loop") rc.traceLog("entering process loop")
handler.OnReady(rc) rc.processUserCallback(func() { handler.OnReady(rc) })
breaked := false breaked := false
for !breaked { for !breaked {
@ -181,7 +232,7 @@ func (rc *Connection) Process(handler Handler) error {
break break
case err := <-rc.errorChannel: case err := <-rc.errorChannel:
rc.Conn.Close() rc.Conn.Close()
handler.OnClosed(err) rc.processUserCallback(func() { handler.OnClosed(err) })
return err return err
} }
@ -190,7 +241,9 @@ func (rc *Connection) Process(handler Handler) error {
res := new(Protocol_Data_Control.Packet) res := new(Protocol_Data_Control.Packet)
err := proto.Unmarshal(packet.Data[:], res) err := proto.Unmarshal(packet.Data[:], res)
if err == nil { if err == nil {
rc.controlPacket(handler, res) // Wrap controlPacket in processUserCallback, since it calls out in many
// places, and wrapping the rest is harmless.
rc.processUserCallback(func() { rc.controlPacket(handler, res) })
} }
} else { } else {
// Let's check to see if we have defined this channel. // Let's check to see if we have defined this channel.
@ -199,11 +252,11 @@ func (rc *Connection) Process(handler Handler) error {
if len(packet.Data) == 0 { if len(packet.Data) == 0 {
rc.traceLog(fmt.Sprintf("removing channel %d", packet.Channel)) rc.traceLog(fmt.Sprintf("removing channel %d", packet.Channel))
rc.channelManager.RemoveChannel(packet.Channel) rc.channelManager.RemoveChannel(packet.Channel)
channel.Handler.Closed(utils.ChannelClosedByPeerError) rc.processUserCallback(func() { channel.Handler.Closed(utils.ChannelClosedByPeerError) })
} else { } else {
rc.traceLog(fmt.Sprintf("received packet on %v channel %d", channel.Handler.Type(), packet.Channel)) rc.traceLog(fmt.Sprintf("received packet on %v channel %d", channel.Handler.Type(), packet.Channel))
// Send The Ricochet Packet to the Handler // Send The Ricochet Packet to the Handler
channel.Handler.Packet(packet.Data[:]) rc.processUserCallback(func() { channel.Handler.Packet(packet.Data[:]) })
} }
} else { } else {
// When a non-zero packet is received for an unknown // When a non-zero packet is received for an unknown

View File

@ -64,13 +64,19 @@ func (echobot *RicochetEchoBot) Connect(privateKeyFile string, hostname string)
go rc.Process(echobot) go rc.Process(echobot)
if !known { if !known {
_, err := rc.RequestOpenChannel("im.ricochet.contact.request", &channels.ContactRequestChannel{Handler: echobot}) err := rc.Do(func() error {
_, err := rc.RequestOpenChannel("im.ricochet.contact.request", &channels.ContactRequestChannel{Handler: echobot})
return err
})
if err != nil { if err != nil {
log.Printf("could not contact %s", err) log.Printf("could not contact %s", err)
} }
} }
rc.RequestOpenChannel("im.ricochet.chat", &channels.ChatChannel{Handler: echobot}) rc.Do(func() error {
_, err := rc.RequestOpenChannel("im.ricochet.chat", &channels.ChatChannel{Handler: echobot})
return err
})
for { for {
message := <-echobot.messages message := <-echobot.messages
log.Printf("Received Message: %s", message) log.Printf("Received Message: %s", message)