Consolidating Channel Building / Teardown logic
This commit is contained in:
parent
1e33c17ae3
commit
9191b7530e
|
@ -204,37 +204,44 @@ func (rc *Connection) DoContext(ctx context.Context, do func(context.Context) er
|
|||
// channel was opened successfully, because channels open asynchronously).
|
||||
func (rc *Connection) RequestOpenChannel(ctype string, handler channels.Handler) (*channels.Channel, error) {
|
||||
rc.traceLog(fmt.Sprintf("requesting open channel of type %s", ctype))
|
||||
channel, err := rc.buildChannel(handler, rc.channelManager.OpenChannelRequest)
|
||||
if err == nil {
|
||||
response, err := handler.OpenOutbound(channel)
|
||||
return rc.handleChannelOpening(channel, response, err)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
func (rc *Connection) handleChannelOpening(channel *channels.Channel, response []byte, err error) (*channels.Channel, error) {
|
||||
if err == nil {
|
||||
rc.SendRicochetPacket(rc.Conn, 0, response)
|
||||
return channel, nil
|
||||
}
|
||||
rc.traceLog(fmt.Sprintf("failed to request open channel: %v", err))
|
||||
rc.channelManager.RemoveChannel(channel.ID)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
func (rc *Connection) buildChannel(handler channels.Handler, openChannelFunc func(handler channels.Handler) (*channels.Channel, error)) (*channels.Channel, error) {
|
||||
|
||||
err := rc.am.Authorized(handler.RequiresAuthentication())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
channel, err := rc.channelManager.OpenChannelRequest(handler)
|
||||
|
||||
if err != nil {
|
||||
rc.traceLog(fmt.Sprintf("failed to request open channel of type %v", err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
channel.SendMessage = func(message []byte) {
|
||||
rc.SendRicochetPacket(rc.Conn, channel.ID, message)
|
||||
}
|
||||
channel.DelegateAuthorization = func() {
|
||||
rc.am.AddAuthorization(handler.Type())
|
||||
}
|
||||
channel.CloseChannel = func() {
|
||||
rc.SendRicochetPacket(rc.Conn, channel.ID, []byte{})
|
||||
rc.channelManager.RemoveChannel(channel.ID)
|
||||
}
|
||||
response, err := handler.OpenOutbound(channel)
|
||||
if err == nil {
|
||||
rc.traceLog(fmt.Sprintf("requested open channel of type %s", ctype))
|
||||
rc.SendRicochetPacket(rc.Conn, 0, response)
|
||||
} else {
|
||||
rc.traceLog(fmt.Sprintf("failed to request open channel of type %v", err))
|
||||
rc.channelManager.RemoveChannel(channel.ID)
|
||||
channel, err := openChannelFunc(handler)
|
||||
if err == nil {
|
||||
channel.SendMessage = func(message []byte) {
|
||||
rc.SendRicochetPacket(rc.Conn, channel.ID, message)
|
||||
}
|
||||
channel.DelegateAuthorization = func() {
|
||||
rc.am.AddAuthorization(handler.Type())
|
||||
}
|
||||
channel.CloseChannel = func() {
|
||||
rc.SendRicochetPacket(rc.Conn, channel.ID, []byte{})
|
||||
rc.channelManager.RemoveChannel(channel.ID)
|
||||
}
|
||||
return channel, nil
|
||||
}
|
||||
}
|
||||
return channel, nil
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// processUserCallback should be used to wrap any calls into handlers or
|
||||
|
@ -374,38 +381,17 @@ func (rc *Connection) controlPacket(handler Handler, res *Protocol_Data_Control.
|
|||
|
||||
opm := res.GetOpenChannel()
|
||||
chandler, err := handler.OnOpenChannelRequest(opm.GetChannelType())
|
||||
|
||||
// Check that we have the authentication already
|
||||
if err == nil {
|
||||
err = rc.am.Authorized(chandler.RequiresAuthentication())
|
||||
if err == nil {
|
||||
rc.traceLog("succeeded authorization check")
|
||||
channel, err := rc.channelManager.OpenChannelRequestFromPeer(opm.GetChannelIdentifier(), chandler)
|
||||
if err == nil {
|
||||
|
||||
channel.SendMessage = func(message []byte) {
|
||||
rc.SendRicochetPacket(rc.Conn, channel.ID, message)
|
||||
}
|
||||
channel.DelegateAuthorization = func() {
|
||||
rc.am.AddAuthorization(chandler.Type())
|
||||
}
|
||||
channel.CloseChannel = func() {
|
||||
rc.SendRicochetPacket(rc.Conn, channel.ID, []byte{})
|
||||
rc.channelManager.RemoveChannel(channel.ID)
|
||||
}
|
||||
|
||||
response, err := chandler.OpenInbound(channel, opm)
|
||||
if err == nil && channel.Pending == false {
|
||||
rc.traceLog(fmt.Sprintf("opening channel %v on %v", channel.Type, channel.ID))
|
||||
rc.SendRicochetPacket(rc.Conn, 0, response)
|
||||
} else {
|
||||
rc.traceLog(fmt.Sprintf("removing channel %v", channel.ID))
|
||||
rc.channelManager.RemoveChannel(channel.ID)
|
||||
rc.SendRicochetPacket(rc.Conn, 0, []byte{})
|
||||
}
|
||||
return
|
||||
}
|
||||
openChannel := func(chandler channels.Handler) (*channels.Channel, error) {
|
||||
return rc.channelManager.OpenChannelRequestFromPeer(opm.GetChannelIdentifier(), chandler)
|
||||
}
|
||||
channel, err := rc.buildChannel(chandler, openChannel)
|
||||
response, err := chandler.OpenInbound(channel, opm)
|
||||
_, err = rc.handleChannelOpening(channel, response, err)
|
||||
if err != nil {
|
||||
rc.SendRicochetPacket(rc.Conn, 0, []byte{})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
errorText := "GenericError"
|
||||
|
|
Reference in New Issue