Extract Authorization Logic / Refactoring OpenChannel Control Logic

This commit is contained in:
Sarah Jamie Lewis 2018-01-15 13:07:54 -05:00
parent 4994e54025
commit 1e33c17ae3
3 changed files with 92 additions and 57 deletions

View File

@ -0,0 +1,28 @@
package connection
import (
"github.com/s-rah/go-ricochet/utils"
)
type AuthorizationManager struct {
Authorizations map[string]bool
}
// Init sets up an AuthorizationManager to be used.
func (am *AuthorizationManager) Init() {
am.Authorizations = make(map[string]bool)
}
// AddAuthorization adds the string authz to the map of allowed authorizations
func (am *AuthorizationManager) AddAuthorization(authz string) {
am.Authorizations[authz] = true
}
// Authorized returns no error in the case an authz type is authorized, error otherwise.
func (am *AuthorizationManager) Authorized(authz string) error {
_, authed := am.Authorizations[authz]
if !authed {
return utils.UnauthorizedActionError
}
return nil
}

View File

@ -0,0 +1,18 @@
package connection
import (
"testing"
)
func TestAuthorizationManager(t *testing.T) {
am := new(AuthorizationManager)
am.Init()
am.AddAuthorization("test")
if am.Authorized("test") != nil {
t.Errorf("Authorized(test) should return nil, instead returned error: %v", am.Authorized("test"))
}
if am.Authorized("not_authed") == nil {
t.Errorf("Authorized(not_authed) should return error, instead returned nil: %v", am.Authorized("not_authed"))
}
}

View File

@ -42,7 +42,7 @@ type Connection struct {
Conn io.ReadWriteCloser
IsInbound bool
Authentication map[string]bool
am AuthorizationManager
RemoteHostname string
SupportChannels []string
}
@ -58,7 +58,8 @@ func (rc *Connection) init() {
rc.unlockChannel = make(chan bool)
rc.unlockResponseChannel = make(chan bool)
rc.Authentication = make(map[string]bool)
rc.am.Init()
rc.am.AddAuthorization("none")
go rc.start()
}
@ -204,15 +205,10 @@ func (rc *Connection) DoContext(ctx context.Context, do func(context.Context) er
func (rc *Connection) RequestOpenChannel(ctype string, handler channels.Handler) (*channels.Channel, error) {
rc.traceLog(fmt.Sprintf("requesting open channel of type %s", ctype))
// Check that we have the authentication already
if handler.RequiresAuthentication() != "none" {
// Enforce Authentication Check.
_, authed := rc.Authentication[handler.RequiresAuthentication()]
if !authed {
return nil, utils.UnauthorizedActionError
}
err := rc.am.Authorized(handler.RequiresAuthentication())
if err != nil {
return nil, err
}
channel, err := rc.channelManager.OpenChannelRequest(handler)
if err != nil {
@ -224,7 +220,7 @@ func (rc *Connection) RequestOpenChannel(ctype string, handler channels.Handler)
rc.SendRicochetPacket(rc.Conn, channel.ID, message)
}
channel.DelegateAuthorization = func() {
rc.Authentication[handler.Type()] = true
rc.am.AddAuthorization(handler.Type())
}
channel.CloseChannel = func() {
rc.SendRicochetPacket(rc.Conn, channel.ID, []byte{})
@ -379,58 +375,51 @@ func (rc *Connection) controlPacket(handler Handler, res *Protocol_Data_Control.
opm := res.GetOpenChannel()
chandler, err := handler.OnOpenChannelRequest(opm.GetChannelType())
if err != nil {
response := rc.messageBuilder.RejectOpenChannel(opm.GetChannelIdentifier(), "UnknownTypeError")
rc.SendRicochetPacket(rc.Conn, 0, response)
return
}
// Check that we have the authentication already
if chandler.RequiresAuthentication() != "none" {
rc.traceLog(fmt.Sprintf("channel %v requires authorization of type %v", chandler.Type(), chandler.RequiresAuthentication()))
// Enforce Authentication Check.
_, authed := rc.Authentication[chandler.RequiresAuthentication()]
if !authed {
response := rc.messageBuilder.RejectOpenChannel(opm.GetChannelIdentifier(), "UnauthorizedError")
rc.SendRicochetPacket(rc.Conn, 0, response)
rc.traceLog(fmt.Sprintf("do not have required authorization to open channel type %v", chandler.Type()))
return
}
rc.traceLog("succeeded authorization check")
}
channel, err := rc.channelManager.OpenChannelRequestFromPeer(opm.GetChannelIdentifier(), chandler)
if err == nil {
err = rc.am.Authorized(chandler.RequiresAuthentication())
if err == nil {
rc.traceLog("succeeded authorization check")
channel, err := rc.channelManager.OpenChannelRequestFromPeer(opm.GetChannelIdentifier(), chandler)
if err == nil {
channel.SendMessage = func(message []byte) {
rc.SendRicochetPacket(rc.Conn, channel.ID, message)
}
channel.DelegateAuthorization = func() {
rc.Authentication[chandler.Type()] = true
}
channel.CloseChannel = func() {
rc.SendRicochetPacket(rc.Conn, channel.ID, []byte{})
rc.channelManager.RemoveChannel(channel.ID)
}
channel.SendMessage = func(message []byte) {
rc.SendRicochetPacket(rc.Conn, channel.ID, message)
}
channel.DelegateAuthorization = func() {
rc.am.AddAuthorization(chandler.Type())
}
channel.CloseChannel = func() {
rc.SendRicochetPacket(rc.Conn, channel.ID, []byte{})
rc.channelManager.RemoveChannel(channel.ID)
}
response, err := chandler.OpenInbound(channel, opm)
if err == nil && channel.Pending == false {
rc.traceLog(fmt.Sprintf("opening channel %v on %v", channel.Type, channel.ID))
rc.SendRicochetPacket(rc.Conn, 0, response)
} else {
rc.traceLog(fmt.Sprintf("removing channel %v", channel.ID))
rc.channelManager.RemoveChannel(channel.ID)
rc.SendRicochetPacket(rc.Conn, 0, []byte{})
response, err := chandler.OpenInbound(channel, opm)
if err == nil && channel.Pending == false {
rc.traceLog(fmt.Sprintf("opening channel %v on %v", channel.Type, channel.ID))
rc.SendRicochetPacket(rc.Conn, 0, response)
} else {
rc.traceLog(fmt.Sprintf("removing channel %v", channel.ID))
rc.channelManager.RemoveChannel(channel.ID)
rc.SendRicochetPacket(rc.Conn, 0, []byte{})
}
return
}
}
} else {
// Send Error Packet
response := rc.messageBuilder.RejectOpenChannel(opm.GetChannelIdentifier(), "GenericError")
rc.traceLog(fmt.Sprintf("sending reject open channel for %v", opm.GetChannelIdentifier()))
rc.SendRicochetPacket(rc.Conn, 0, response)
}
errorText := "GenericError"
switch err {
case utils.UnknownChannelTypeError:
errorText = "UnknownTypeError"
case utils.UnauthorizedChannelTypeError:
errorText = "UnauthorizedTypeError"
}
// Send Error Packet
response := rc.messageBuilder.RejectOpenChannel(opm.GetChannelIdentifier(), errorText)
rc.traceLog(fmt.Sprintf("sending reject open channel for %v", opm.GetChannelIdentifier()))
rc.SendRicochetPacket(rc.Conn, 0, response)
} else if res.GetChannelResult() != nil {
rc.ctrlChannel.ProcessChannelResult(res.GetChannelResult())
} else if res.GetKeepAlive() != nil {