From a04b3fe08b25bd7ae9a67518fe2745eb2dedc4fa Mon Sep 17 00:00:00 2001 From: Sarah Jamie Lewis Date: Sun, 14 Jan 2018 11:45:30 -0500 Subject: [PATCH] Moving ChannelResult processing to ControlChannel file and testing --- connection/connection.go | 27 ++++-------------- connection/control_channel.go | 32 +++++++++++++++++++-- connection/control_channel_test.go | 45 ++++++++++++++++++++++++++++-- utils/error.go | 1 + 4 files changed, 79 insertions(+), 26 deletions(-) diff --git a/connection/connection.go b/connection/connection.go index d73e66e..3a94861 100644 --- a/connection/connection.go +++ b/connection/connection.go @@ -2,7 +2,6 @@ package connection import ( "context" - "errors" "fmt" "github.com/golang/protobuf/proto" "github.com/s-rah/go-ricochet/channels" @@ -19,6 +18,7 @@ type Connection struct { utils.RicochetNetwork channelManager *ChannelManager + ctrlChannel ControlChannel // Ricochet Network Loop packetChannel chan utils.RicochetData @@ -70,6 +70,7 @@ func NewInboundConnection(conn io.ReadWriteCloser) *Connection { rc.IsInbound = true rc.init() rc.channelManager = NewServerChannelManager() + rc.ctrlChannel.Init(rc.channelManager) return rc } @@ -82,6 +83,7 @@ func NewOutboundConnection(conn io.ReadWriteCloser, remoteHostname string) *Conn rc.init() rc.RemoteHostname = remoteHostname rc.channelManager = NewClientChannelManager() + rc.ctrlChannel.Init(rc.channelManager) return rc } @@ -430,37 +432,20 @@ func (rc *Connection) controlPacket(handler Handler, res *Protocol_Data_Control. } } else if res.GetChannelResult() != nil { - cr := res.GetChannelResult() - id := cr.GetChannelIdentifier() - - channel, found := rc.channelManager.GetChannel(id) - - if !found { - rc.traceLog(fmt.Sprintf("channel result received for unknown channel: type:%s id:%v", channel.Type, id)) - return - } - - if cr.GetOpened() { - rc.traceLog(fmt.Sprintf("channel of type %v opened on %v", channel.Type, id)) - channel.Handler.OpenOutboundResult(nil, cr) - } else { - rc.traceLog(fmt.Sprintf("channel of type %v rejected on %v", channel.Type, id)) - channel.Handler.OpenOutboundResult(errors.New(cr.GetCommonError().String()), cr) - } - + rc.ctrlChannel.ProcessChannelResult(res.GetChannelResult()) } else if res.GetKeepAlive() != nil { // XXX Though not currently part of the protocol // We should likely put these calls behind // authentication. rc.traceLog("received keep alive packet") - respond, data := ProcessKeepAlive(res.GetKeepAlive()) + respond, data := rc.ctrlChannel.ProcessKeepAlive(res.GetKeepAlive()) if respond { rc.traceLog("sending keep alive response") rc.SendRicochetPacket(rc.Conn, 0, data) } } else if res.GetEnableFeatures() != nil { rc.traceLog("received enable features packet") - data := ProcessEnableFeatures(handler, res.GetEnableFeatures()) + data := rc.ctrlChannel.ProcessEnableFeatures(handler, res.GetEnableFeatures()) rc.traceLog(fmt.Sprintf("sending featured enabled: %v", data)) rc.SendRicochetPacket(rc.Conn, 0, data) } else if res.GetFeaturesEnabled() != nil { diff --git a/connection/control_channel.go b/connection/control_channel.go index 6ca2eae..2954f9e 100644 --- a/connection/control_channel.go +++ b/connection/control_channel.go @@ -1,12 +1,40 @@ package connection import ( + "errors" "github.com/s-rah/go-ricochet/utils" "github.com/s-rah/go-ricochet/wire/control" ) +type ControlChannel struct { + channelManager *ChannelManager +} + +func (ctrl *ControlChannel) Init(channelManager *ChannelManager) { + ctrl.channelManager = channelManager +} + +func (ctrl *ControlChannel) ProcessChannelResult(cr *Protocol_Data_Control.ChannelResult) (bool, error) { + id := cr.GetChannelIdentifier() + + channel, found := ctrl.channelManager.GetChannel(id) + + if !found { + return false, utils.UnexpectedChannelResultError + } + + if cr.GetOpened() { + //rc.traceLog(fmt.Sprintf("channel of type %v opened on %v", channel.Type, id)) + channel.Handler.OpenOutboundResult(nil, cr) + return true, nil + } + //rc.traceLog(fmt.Sprintf("channel of type %v rejected on %v", channel.Type, id)) + channel.Handler.OpenOutboundResult(errors.New(cr.GetCommonError().String()), cr) + return false, nil +} + // ProcessKeepAlive -func ProcessKeepAlive(ka *Protocol_Data_Control.KeepAlive) (bool, []byte) { +func (ctrl *ControlChannel) ProcessKeepAlive(ka *Protocol_Data_Control.KeepAlive) (bool, []byte) { if ka.GetResponseRequested() { messageBuilder := new(utils.MessageBuilder) return true, messageBuilder.KeepAlive(true) @@ -15,7 +43,7 @@ func ProcessKeepAlive(ka *Protocol_Data_Control.KeepAlive) (bool, []byte) { } // ProcessEnableFeatures correctly handles a features enabled packet -func ProcessEnableFeatures(handler Handler, ef *Protocol_Data_Control.EnableFeatures) []byte { +func (ctrl *ControlChannel) ProcessEnableFeatures(handler Handler, ef *Protocol_Data_Control.EnableFeatures) []byte { featuresToEnable := ef.GetFeature() supportChannels := handler.GetSupportedChannelTypes() result := []string{} diff --git a/connection/control_channel_test.go b/connection/control_channel_test.go index b36f869..ca872a0 100644 --- a/connection/control_channel_test.go +++ b/connection/control_channel_test.go @@ -2,6 +2,8 @@ package connection import ( "github.com/golang/protobuf/proto" + "github.com/s-rah/go-ricochet/channels" + "github.com/s-rah/go-ricochet/utils" "github.com/s-rah/go-ricochet/wire/control" "testing" ) @@ -14,11 +16,46 @@ func (m *MockHandler) GetSupportedChannelTypes() []string { return []string{"im.ricochet.chat"} } +func TestChannelResultNotOpened(t *testing.T) { + ccm := NewClientChannelManager() + ctrlChannel := new(ControlChannel) + ctrlChannel.Init(ccm) + chatChannel := new(channels.ChatChannel) + _, err := ccm.OpenChannelRequestFromPeer(2, chatChannel) + + cr := &Protocol_Data_Control.ChannelResult{ + ChannelIdentifier: proto.Int32(2), + Opened: proto.Bool(false), + } + opened, err := ctrlChannel.ProcessChannelResult(cr) + if opened != false || err != nil { + t.Errorf("ProcessChannelResult should have resulted in n channel being opened, and no error %v %v", opened, err) + } +} + +func TestChannelResultError(t *testing.T) { + ccm := NewClientChannelManager() + ctrlChannel := new(ControlChannel) + ctrlChannel.Init(ccm) + chatChannel := new(channels.ChatChannel) + _, err := ccm.OpenChannelRequestFromPeer(2, chatChannel) + + cr := &Protocol_Data_Control.ChannelResult{ + ChannelIdentifier: proto.Int32(3), + Opened: proto.Bool(false), + } + opened, err := ctrlChannel.ProcessChannelResult(cr) + if opened != false || err != utils.UnexpectedChannelResultError { + t.Errorf("ProcessChannelResult should have resulted in n channel being opened, and an error %v %v", opened, err) + } +} + func TestKeepAliveNoResponse(t *testing.T) { + ctrlChannel := new(ControlChannel) ka := &Protocol_Data_Control.KeepAlive{ ResponseRequested: proto.Bool(false), } - respond, _ := ProcessKeepAlive(ka) + respond, _ := ctrlChannel.ProcessKeepAlive(ka) if respond == true { t.Errorf("KeepAlive process should have not needed a response %v %v", ka, respond) } @@ -28,7 +65,8 @@ func TestKeepAliveRequestResponse(t *testing.T) { ka := &Protocol_Data_Control.KeepAlive{ ResponseRequested: proto.Bool(true), } - respond, _ := ProcessKeepAlive(ka) + ctrlChannel := new(ControlChannel) + respond, _ := ctrlChannel.ProcessKeepAlive(ka) if respond == false { t.Errorf("KeepAlive process should have produced a response %v %v", ka, respond) } @@ -40,7 +78,8 @@ func TestEnableFeatures(t *testing.T) { ef := &Protocol_Data_Control.EnableFeatures{ Feature: features, } - raw := ProcessEnableFeatures(handler, ef) + ctrlChannel := new(ControlChannel) + raw := ctrlChannel.ProcessEnableFeatures(handler, ef) res := new(Protocol_Data_Control.Packet) err := proto.Unmarshal(raw, res) if err != nil || res.GetFeaturesEnabled() == nil { diff --git a/utils/error.go b/utils/error.go index d65fb0d..4eb2475 100644 --- a/utils/error.go +++ b/utils/error.go @@ -19,6 +19,7 @@ const ( UnknownChannelTypeError = Error("UnknownChannelTypeError") UnauthorizedChannelTypeError = Error("UnauthorizedChannelTypeError") + UnexpectedChannelResultError = Error("UnexpectedChannelResultError") // Timeout Errors ActionTimedOutError = Error("ActionTimedOutError")