From 1e33c17ae3f0c74672cc0333eb2b3a076e702e09 Mon Sep 17 00:00:00 2001 From: Sarah Jamie Lewis Date: Mon, 15 Jan 2018 13:07:54 -0500 Subject: [PATCH] Extract Authorization Logic / Refactoring OpenChannel Control Logic --- connection/authorizationmanager.go | 28 +++++++ connection/authorizationmanager_test.go | 18 +++++ connection/connection.go | 103 +++++++++++------------- 3 files changed, 92 insertions(+), 57 deletions(-) create mode 100644 connection/authorizationmanager.go create mode 100644 connection/authorizationmanager_test.go diff --git a/connection/authorizationmanager.go b/connection/authorizationmanager.go new file mode 100644 index 0000000..9775f7a --- /dev/null +++ b/connection/authorizationmanager.go @@ -0,0 +1,28 @@ +package connection + +import ( + "github.com/s-rah/go-ricochet/utils" +) + +type AuthorizationManager struct { + Authorizations map[string]bool +} + +// Init sets up an AuthorizationManager to be used. +func (am *AuthorizationManager) Init() { + am.Authorizations = make(map[string]bool) +} + +// AddAuthorization adds the string authz to the map of allowed authorizations +func (am *AuthorizationManager) AddAuthorization(authz string) { + am.Authorizations[authz] = true +} + +// Authorized returns no error in the case an authz type is authorized, error otherwise. +func (am *AuthorizationManager) Authorized(authz string) error { + _, authed := am.Authorizations[authz] + if !authed { + return utils.UnauthorizedActionError + } + return nil +} diff --git a/connection/authorizationmanager_test.go b/connection/authorizationmanager_test.go new file mode 100644 index 0000000..6521a4b --- /dev/null +++ b/connection/authorizationmanager_test.go @@ -0,0 +1,18 @@ +package connection + +import ( + "testing" +) + +func TestAuthorizationManager(t *testing.T) { + am := new(AuthorizationManager) + am.Init() + am.AddAuthorization("test") + if am.Authorized("test") != nil { + t.Errorf("Authorized(test) should return nil, instead returned error: %v", am.Authorized("test")) + } + + if am.Authorized("not_authed") == nil { + t.Errorf("Authorized(not_authed) should return error, instead returned nil: %v", am.Authorized("not_authed")) + } +} diff --git a/connection/connection.go b/connection/connection.go index 3a94861..392821e 100644 --- a/connection/connection.go +++ b/connection/connection.go @@ -42,7 +42,7 @@ type Connection struct { Conn io.ReadWriteCloser IsInbound bool - Authentication map[string]bool + am AuthorizationManager RemoteHostname string SupportChannels []string } @@ -58,7 +58,8 @@ func (rc *Connection) init() { rc.unlockChannel = make(chan bool) rc.unlockResponseChannel = make(chan bool) - rc.Authentication = make(map[string]bool) + rc.am.Init() + rc.am.AddAuthorization("none") go rc.start() } @@ -204,15 +205,10 @@ func (rc *Connection) DoContext(ctx context.Context, do func(context.Context) er func (rc *Connection) RequestOpenChannel(ctype string, handler channels.Handler) (*channels.Channel, error) { rc.traceLog(fmt.Sprintf("requesting open channel of type %s", ctype)) - // Check that we have the authentication already - if handler.RequiresAuthentication() != "none" { - // Enforce Authentication Check. - _, authed := rc.Authentication[handler.RequiresAuthentication()] - if !authed { - return nil, utils.UnauthorizedActionError - } + err := rc.am.Authorized(handler.RequiresAuthentication()) + if err != nil { + return nil, err } - channel, err := rc.channelManager.OpenChannelRequest(handler) if err != nil { @@ -224,7 +220,7 @@ func (rc *Connection) RequestOpenChannel(ctype string, handler channels.Handler) rc.SendRicochetPacket(rc.Conn, channel.ID, message) } channel.DelegateAuthorization = func() { - rc.Authentication[handler.Type()] = true + rc.am.AddAuthorization(handler.Type()) } channel.CloseChannel = func() { rc.SendRicochetPacket(rc.Conn, channel.ID, []byte{}) @@ -379,58 +375,51 @@ func (rc *Connection) controlPacket(handler Handler, res *Protocol_Data_Control. opm := res.GetOpenChannel() chandler, err := handler.OnOpenChannelRequest(opm.GetChannelType()) - if err != nil { - - response := rc.messageBuilder.RejectOpenChannel(opm.GetChannelIdentifier(), "UnknownTypeError") - rc.SendRicochetPacket(rc.Conn, 0, response) - return - } - // Check that we have the authentication already - if chandler.RequiresAuthentication() != "none" { - rc.traceLog(fmt.Sprintf("channel %v requires authorization of type %v", chandler.Type(), chandler.RequiresAuthentication())) - // Enforce Authentication Check. - _, authed := rc.Authentication[chandler.RequiresAuthentication()] - if !authed { - response := rc.messageBuilder.RejectOpenChannel(opm.GetChannelIdentifier(), "UnauthorizedError") - rc.SendRicochetPacket(rc.Conn, 0, response) - rc.traceLog(fmt.Sprintf("do not have required authorization to open channel type %v", chandler.Type())) - return - } - rc.traceLog("succeeded authorization check") - } - - channel, err := rc.channelManager.OpenChannelRequestFromPeer(opm.GetChannelIdentifier(), chandler) - if err == nil { + err = rc.am.Authorized(chandler.RequiresAuthentication()) + if err == nil { + rc.traceLog("succeeded authorization check") + channel, err := rc.channelManager.OpenChannelRequestFromPeer(opm.GetChannelIdentifier(), chandler) + if err == nil { - channel.SendMessage = func(message []byte) { - rc.SendRicochetPacket(rc.Conn, channel.ID, message) - } - channel.DelegateAuthorization = func() { - rc.Authentication[chandler.Type()] = true - } - channel.CloseChannel = func() { - rc.SendRicochetPacket(rc.Conn, channel.ID, []byte{}) - rc.channelManager.RemoveChannel(channel.ID) - } + channel.SendMessage = func(message []byte) { + rc.SendRicochetPacket(rc.Conn, channel.ID, message) + } + channel.DelegateAuthorization = func() { + rc.am.AddAuthorization(chandler.Type()) + } + channel.CloseChannel = func() { + rc.SendRicochetPacket(rc.Conn, channel.ID, []byte{}) + rc.channelManager.RemoveChannel(channel.ID) + } - response, err := chandler.OpenInbound(channel, opm) - if err == nil && channel.Pending == false { - rc.traceLog(fmt.Sprintf("opening channel %v on %v", channel.Type, channel.ID)) - rc.SendRicochetPacket(rc.Conn, 0, response) - } else { - rc.traceLog(fmt.Sprintf("removing channel %v", channel.ID)) - rc.channelManager.RemoveChannel(channel.ID) - rc.SendRicochetPacket(rc.Conn, 0, []byte{}) + response, err := chandler.OpenInbound(channel, opm) + if err == nil && channel.Pending == false { + rc.traceLog(fmt.Sprintf("opening channel %v on %v", channel.Type, channel.ID)) + rc.SendRicochetPacket(rc.Conn, 0, response) + } else { + rc.traceLog(fmt.Sprintf("removing channel %v", channel.ID)) + rc.channelManager.RemoveChannel(channel.ID) + rc.SendRicochetPacket(rc.Conn, 0, []byte{}) + } + return + } } - } else { - // Send Error Packet - response := rc.messageBuilder.RejectOpenChannel(opm.GetChannelIdentifier(), "GenericError") - rc.traceLog(fmt.Sprintf("sending reject open channel for %v", opm.GetChannelIdentifier())) - rc.SendRicochetPacket(rc.Conn, 0, response) - } + + errorText := "GenericError" + switch err { + case utils.UnknownChannelTypeError: + errorText = "UnknownTypeError" + case utils.UnauthorizedChannelTypeError: + errorText = "UnauthorizedTypeError" + } + // Send Error Packet + response := rc.messageBuilder.RejectOpenChannel(opm.GetChannelIdentifier(), errorText) + rc.traceLog(fmt.Sprintf("sending reject open channel for %v", opm.GetChannelIdentifier())) + rc.SendRicochetPacket(rc.Conn, 0, response) + } else if res.GetChannelResult() != nil { rc.ctrlChannel.ProcessChannelResult(res.GetChannelResult()) } else if res.GetKeepAlive() != nil {