Consolidating Channel Building / Teardown logic

This commit is contained in:
Sarah Jamie Lewis 2018-01-16 11:53:34 -05:00
parent 1e33c17ae3
commit 9191b7530e
1 changed files with 43 additions and 57 deletions

View File

@ -204,37 +204,44 @@ func (rc *Connection) DoContext(ctx context.Context, do func(context.Context) er
// channel was opened successfully, because channels open asynchronously).
func (rc *Connection) RequestOpenChannel(ctype string, handler channels.Handler) (*channels.Channel, error) {
rc.traceLog(fmt.Sprintf("requesting open channel of type %s", ctype))
channel, err := rc.buildChannel(handler, rc.channelManager.OpenChannelRequest)
if err == nil {
response, err := handler.OpenOutbound(channel)
return rc.handleChannelOpening(channel, response, err)
}
return nil, err
}
func (rc *Connection) handleChannelOpening(channel *channels.Channel, response []byte, err error) (*channels.Channel, error) {
if err == nil {
rc.SendRicochetPacket(rc.Conn, 0, response)
return channel, nil
}
rc.traceLog(fmt.Sprintf("failed to request open channel: %v", err))
rc.channelManager.RemoveChannel(channel.ID)
return nil, err
}
func (rc *Connection) buildChannel(handler channels.Handler, openChannelFunc func(handler channels.Handler) (*channels.Channel, error)) (*channels.Channel, error) {
err := rc.am.Authorized(handler.RequiresAuthentication())
if err != nil {
return nil, err
}
channel, err := rc.channelManager.OpenChannelRequest(handler)
if err != nil {
rc.traceLog(fmt.Sprintf("failed to request open channel of type %v", err))
return nil, err
}
channel.SendMessage = func(message []byte) {
rc.SendRicochetPacket(rc.Conn, channel.ID, message)
}
channel.DelegateAuthorization = func() {
rc.am.AddAuthorization(handler.Type())
}
channel.CloseChannel = func() {
rc.SendRicochetPacket(rc.Conn, channel.ID, []byte{})
rc.channelManager.RemoveChannel(channel.ID)
}
response, err := handler.OpenOutbound(channel)
if err == nil {
rc.traceLog(fmt.Sprintf("requested open channel of type %s", ctype))
rc.SendRicochetPacket(rc.Conn, 0, response)
} else {
rc.traceLog(fmt.Sprintf("failed to request open channel of type %v", err))
rc.channelManager.RemoveChannel(channel.ID)
channel, err := openChannelFunc(handler)
if err == nil {
channel.SendMessage = func(message []byte) {
rc.SendRicochetPacket(rc.Conn, channel.ID, message)
}
channel.DelegateAuthorization = func() {
rc.am.AddAuthorization(handler.Type())
}
channel.CloseChannel = func() {
rc.SendRicochetPacket(rc.Conn, channel.ID, []byte{})
rc.channelManager.RemoveChannel(channel.ID)
}
return channel, nil
}
}
return channel, nil
return nil, err
}
// processUserCallback should be used to wrap any calls into handlers or
@ -374,38 +381,17 @@ func (rc *Connection) controlPacket(handler Handler, res *Protocol_Data_Control.
opm := res.GetOpenChannel()
chandler, err := handler.OnOpenChannelRequest(opm.GetChannelType())
// Check that we have the authentication already
if err == nil {
err = rc.am.Authorized(chandler.RequiresAuthentication())
if err == nil {
rc.traceLog("succeeded authorization check")
channel, err := rc.channelManager.OpenChannelRequestFromPeer(opm.GetChannelIdentifier(), chandler)
if err == nil {
channel.SendMessage = func(message []byte) {
rc.SendRicochetPacket(rc.Conn, channel.ID, message)
}
channel.DelegateAuthorization = func() {
rc.am.AddAuthorization(chandler.Type())
}
channel.CloseChannel = func() {
rc.SendRicochetPacket(rc.Conn, channel.ID, []byte{})
rc.channelManager.RemoveChannel(channel.ID)
}
response, err := chandler.OpenInbound(channel, opm)
if err == nil && channel.Pending == false {
rc.traceLog(fmt.Sprintf("opening channel %v on %v", channel.Type, channel.ID))
rc.SendRicochetPacket(rc.Conn, 0, response)
} else {
rc.traceLog(fmt.Sprintf("removing channel %v", channel.ID))
rc.channelManager.RemoveChannel(channel.ID)
rc.SendRicochetPacket(rc.Conn, 0, []byte{})
}
return
}
openChannel := func(chandler channels.Handler) (*channels.Channel, error) {
return rc.channelManager.OpenChannelRequestFromPeer(opm.GetChannelIdentifier(), chandler)
}
channel, err := rc.buildChannel(chandler, openChannel)
response, err := chandler.OpenInbound(channel, opm)
_, err = rc.handleChannelOpening(channel, response, err)
if err != nil {
rc.SendRicochetPacket(rc.Conn, 0, []byte{})
}
return
}
errorText := "GenericError"