diff --git a/connection/connection.go b/connection/connection.go index 392821e..ac3d3f4 100644 --- a/connection/connection.go +++ b/connection/connection.go @@ -204,37 +204,44 @@ func (rc *Connection) DoContext(ctx context.Context, do func(context.Context) er // channel was opened successfully, because channels open asynchronously). func (rc *Connection) RequestOpenChannel(ctype string, handler channels.Handler) (*channels.Channel, error) { rc.traceLog(fmt.Sprintf("requesting open channel of type %s", ctype)) + channel, err := rc.buildChannel(handler, rc.channelManager.OpenChannelRequest) + if err == nil { + response, err := handler.OpenOutbound(channel) + return rc.handleChannelOpening(channel, response, err) + } + return nil, err +} + +func (rc *Connection) handleChannelOpening(channel *channels.Channel, response []byte, err error) (*channels.Channel, error) { + if err == nil { + rc.SendRicochetPacket(rc.Conn, 0, response) + return channel, nil + } + rc.traceLog(fmt.Sprintf("failed to request open channel: %v", err)) + rc.channelManager.RemoveChannel(channel.ID) + return nil, err +} + +func (rc *Connection) buildChannel(handler channels.Handler, openChannelFunc func(handler channels.Handler) (*channels.Channel, error)) (*channels.Channel, error) { err := rc.am.Authorized(handler.RequiresAuthentication()) - if err != nil { - return nil, err - } - channel, err := rc.channelManager.OpenChannelRequest(handler) - - if err != nil { - rc.traceLog(fmt.Sprintf("failed to request open channel of type %v", err)) - return nil, err - } - - channel.SendMessage = func(message []byte) { - rc.SendRicochetPacket(rc.Conn, channel.ID, message) - } - channel.DelegateAuthorization = func() { - rc.am.AddAuthorization(handler.Type()) - } - channel.CloseChannel = func() { - rc.SendRicochetPacket(rc.Conn, channel.ID, []byte{}) - rc.channelManager.RemoveChannel(channel.ID) - } - response, err := handler.OpenOutbound(channel) if err == nil { - rc.traceLog(fmt.Sprintf("requested open channel of type %s", ctype)) - rc.SendRicochetPacket(rc.Conn, 0, response) - } else { - rc.traceLog(fmt.Sprintf("failed to request open channel of type %v", err)) - rc.channelManager.RemoveChannel(channel.ID) + channel, err := openChannelFunc(handler) + if err == nil { + channel.SendMessage = func(message []byte) { + rc.SendRicochetPacket(rc.Conn, channel.ID, message) + } + channel.DelegateAuthorization = func() { + rc.am.AddAuthorization(handler.Type()) + } + channel.CloseChannel = func() { + rc.SendRicochetPacket(rc.Conn, channel.ID, []byte{}) + rc.channelManager.RemoveChannel(channel.ID) + } + return channel, nil + } } - return channel, nil + return nil, err } // processUserCallback should be used to wrap any calls into handlers or @@ -374,38 +381,17 @@ func (rc *Connection) controlPacket(handler Handler, res *Protocol_Data_Control. opm := res.GetOpenChannel() chandler, err := handler.OnOpenChannelRequest(opm.GetChannelType()) - - // Check that we have the authentication already if err == nil { - err = rc.am.Authorized(chandler.RequiresAuthentication()) - if err == nil { - rc.traceLog("succeeded authorization check") - channel, err := rc.channelManager.OpenChannelRequestFromPeer(opm.GetChannelIdentifier(), chandler) - if err == nil { - - channel.SendMessage = func(message []byte) { - rc.SendRicochetPacket(rc.Conn, channel.ID, message) - } - channel.DelegateAuthorization = func() { - rc.am.AddAuthorization(chandler.Type()) - } - channel.CloseChannel = func() { - rc.SendRicochetPacket(rc.Conn, channel.ID, []byte{}) - rc.channelManager.RemoveChannel(channel.ID) - } - - response, err := chandler.OpenInbound(channel, opm) - if err == nil && channel.Pending == false { - rc.traceLog(fmt.Sprintf("opening channel %v on %v", channel.Type, channel.ID)) - rc.SendRicochetPacket(rc.Conn, 0, response) - } else { - rc.traceLog(fmt.Sprintf("removing channel %v", channel.ID)) - rc.channelManager.RemoveChannel(channel.ID) - rc.SendRicochetPacket(rc.Conn, 0, []byte{}) - } - return - } + openChannel := func(chandler channels.Handler) (*channels.Channel, error) { + return rc.channelManager.OpenChannelRequestFromPeer(opm.GetChannelIdentifier(), chandler) } + channel, err := rc.buildChannel(chandler, openChannel) + response, err := chandler.OpenInbound(channel, opm) + _, err = rc.handleChannelOpening(channel, response, err) + if err != nil { + rc.SendRicochetPacket(rc.Conn, 0, []byte{}) + } + return } errorText := "GenericError"