Consolidating Channel Building / Teardown logic
This commit is contained in:
parent
1e33c17ae3
commit
9191b7530e
|
@ -204,37 +204,44 @@ func (rc *Connection) DoContext(ctx context.Context, do func(context.Context) er
|
||||||
// channel was opened successfully, because channels open asynchronously).
|
// channel was opened successfully, because channels open asynchronously).
|
||||||
func (rc *Connection) RequestOpenChannel(ctype string, handler channels.Handler) (*channels.Channel, error) {
|
func (rc *Connection) RequestOpenChannel(ctype string, handler channels.Handler) (*channels.Channel, error) {
|
||||||
rc.traceLog(fmt.Sprintf("requesting open channel of type %s", ctype))
|
rc.traceLog(fmt.Sprintf("requesting open channel of type %s", ctype))
|
||||||
|
channel, err := rc.buildChannel(handler, rc.channelManager.OpenChannelRequest)
|
||||||
|
if err == nil {
|
||||||
|
response, err := handler.OpenOutbound(channel)
|
||||||
|
return rc.handleChannelOpening(channel, response, err)
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rc *Connection) handleChannelOpening(channel *channels.Channel, response []byte, err error) (*channels.Channel, error) {
|
||||||
|
if err == nil {
|
||||||
|
rc.SendRicochetPacket(rc.Conn, 0, response)
|
||||||
|
return channel, nil
|
||||||
|
}
|
||||||
|
rc.traceLog(fmt.Sprintf("failed to request open channel: %v", err))
|
||||||
|
rc.channelManager.RemoveChannel(channel.ID)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rc *Connection) buildChannel(handler channels.Handler, openChannelFunc func(handler channels.Handler) (*channels.Channel, error)) (*channels.Channel, error) {
|
||||||
|
|
||||||
err := rc.am.Authorized(handler.RequiresAuthentication())
|
err := rc.am.Authorized(handler.RequiresAuthentication())
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
channel, err := rc.channelManager.OpenChannelRequest(handler)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
rc.traceLog(fmt.Sprintf("failed to request open channel of type %v", err))
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
channel.SendMessage = func(message []byte) {
|
|
||||||
rc.SendRicochetPacket(rc.Conn, channel.ID, message)
|
|
||||||
}
|
|
||||||
channel.DelegateAuthorization = func() {
|
|
||||||
rc.am.AddAuthorization(handler.Type())
|
|
||||||
}
|
|
||||||
channel.CloseChannel = func() {
|
|
||||||
rc.SendRicochetPacket(rc.Conn, channel.ID, []byte{})
|
|
||||||
rc.channelManager.RemoveChannel(channel.ID)
|
|
||||||
}
|
|
||||||
response, err := handler.OpenOutbound(channel)
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
rc.traceLog(fmt.Sprintf("requested open channel of type %s", ctype))
|
channel, err := openChannelFunc(handler)
|
||||||
rc.SendRicochetPacket(rc.Conn, 0, response)
|
if err == nil {
|
||||||
} else {
|
channel.SendMessage = func(message []byte) {
|
||||||
rc.traceLog(fmt.Sprintf("failed to request open channel of type %v", err))
|
rc.SendRicochetPacket(rc.Conn, channel.ID, message)
|
||||||
rc.channelManager.RemoveChannel(channel.ID)
|
}
|
||||||
|
channel.DelegateAuthorization = func() {
|
||||||
|
rc.am.AddAuthorization(handler.Type())
|
||||||
|
}
|
||||||
|
channel.CloseChannel = func() {
|
||||||
|
rc.SendRicochetPacket(rc.Conn, channel.ID, []byte{})
|
||||||
|
rc.channelManager.RemoveChannel(channel.ID)
|
||||||
|
}
|
||||||
|
return channel, nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return channel, nil
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// processUserCallback should be used to wrap any calls into handlers or
|
// processUserCallback should be used to wrap any calls into handlers or
|
||||||
|
@ -374,38 +381,17 @@ func (rc *Connection) controlPacket(handler Handler, res *Protocol_Data_Control.
|
||||||
|
|
||||||
opm := res.GetOpenChannel()
|
opm := res.GetOpenChannel()
|
||||||
chandler, err := handler.OnOpenChannelRequest(opm.GetChannelType())
|
chandler, err := handler.OnOpenChannelRequest(opm.GetChannelType())
|
||||||
|
|
||||||
// Check that we have the authentication already
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
err = rc.am.Authorized(chandler.RequiresAuthentication())
|
openChannel := func(chandler channels.Handler) (*channels.Channel, error) {
|
||||||
if err == nil {
|
return rc.channelManager.OpenChannelRequestFromPeer(opm.GetChannelIdentifier(), chandler)
|
||||||
rc.traceLog("succeeded authorization check")
|
|
||||||
channel, err := rc.channelManager.OpenChannelRequestFromPeer(opm.GetChannelIdentifier(), chandler)
|
|
||||||
if err == nil {
|
|
||||||
|
|
||||||
channel.SendMessage = func(message []byte) {
|
|
||||||
rc.SendRicochetPacket(rc.Conn, channel.ID, message)
|
|
||||||
}
|
|
||||||
channel.DelegateAuthorization = func() {
|
|
||||||
rc.am.AddAuthorization(chandler.Type())
|
|
||||||
}
|
|
||||||
channel.CloseChannel = func() {
|
|
||||||
rc.SendRicochetPacket(rc.Conn, channel.ID, []byte{})
|
|
||||||
rc.channelManager.RemoveChannel(channel.ID)
|
|
||||||
}
|
|
||||||
|
|
||||||
response, err := chandler.OpenInbound(channel, opm)
|
|
||||||
if err == nil && channel.Pending == false {
|
|
||||||
rc.traceLog(fmt.Sprintf("opening channel %v on %v", channel.Type, channel.ID))
|
|
||||||
rc.SendRicochetPacket(rc.Conn, 0, response)
|
|
||||||
} else {
|
|
||||||
rc.traceLog(fmt.Sprintf("removing channel %v", channel.ID))
|
|
||||||
rc.channelManager.RemoveChannel(channel.ID)
|
|
||||||
rc.SendRicochetPacket(rc.Conn, 0, []byte{})
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
channel, err := rc.buildChannel(chandler, openChannel)
|
||||||
|
response, err := chandler.OpenInbound(channel, opm)
|
||||||
|
_, err = rc.handleChannelOpening(channel, response, err)
|
||||||
|
if err != nil {
|
||||||
|
rc.SendRicochetPacket(rc.Conn, 0, []byte{})
|
||||||
|
}
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
errorText := "GenericError"
|
errorText := "GenericError"
|
||||||
|
|
Reference in New Issue