Extract Authorization Logic / Refactoring OpenChannel Control Logic
This commit is contained in:
parent
4994e54025
commit
1e33c17ae3
|
@ -0,0 +1,28 @@
|
||||||
|
package connection
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/s-rah/go-ricochet/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
type AuthorizationManager struct {
|
||||||
|
Authorizations map[string]bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// Init sets up an AuthorizationManager to be used.
|
||||||
|
func (am *AuthorizationManager) Init() {
|
||||||
|
am.Authorizations = make(map[string]bool)
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddAuthorization adds the string authz to the map of allowed authorizations
|
||||||
|
func (am *AuthorizationManager) AddAuthorization(authz string) {
|
||||||
|
am.Authorizations[authz] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Authorized returns no error in the case an authz type is authorized, error otherwise.
|
||||||
|
func (am *AuthorizationManager) Authorized(authz string) error {
|
||||||
|
_, authed := am.Authorizations[authz]
|
||||||
|
if !authed {
|
||||||
|
return utils.UnauthorizedActionError
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
|
@ -0,0 +1,18 @@
|
||||||
|
package connection
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAuthorizationManager(t *testing.T) {
|
||||||
|
am := new(AuthorizationManager)
|
||||||
|
am.Init()
|
||||||
|
am.AddAuthorization("test")
|
||||||
|
if am.Authorized("test") != nil {
|
||||||
|
t.Errorf("Authorized(test) should return nil, instead returned error: %v", am.Authorized("test"))
|
||||||
|
}
|
||||||
|
|
||||||
|
if am.Authorized("not_authed") == nil {
|
||||||
|
t.Errorf("Authorized(not_authed) should return error, instead returned nil: %v", am.Authorized("not_authed"))
|
||||||
|
}
|
||||||
|
}
|
|
@ -42,7 +42,7 @@ type Connection struct {
|
||||||
|
|
||||||
Conn io.ReadWriteCloser
|
Conn io.ReadWriteCloser
|
||||||
IsInbound bool
|
IsInbound bool
|
||||||
Authentication map[string]bool
|
am AuthorizationManager
|
||||||
RemoteHostname string
|
RemoteHostname string
|
||||||
SupportChannels []string
|
SupportChannels []string
|
||||||
}
|
}
|
||||||
|
@ -58,7 +58,8 @@ func (rc *Connection) init() {
|
||||||
rc.unlockChannel = make(chan bool)
|
rc.unlockChannel = make(chan bool)
|
||||||
rc.unlockResponseChannel = make(chan bool)
|
rc.unlockResponseChannel = make(chan bool)
|
||||||
|
|
||||||
rc.Authentication = make(map[string]bool)
|
rc.am.Init()
|
||||||
|
rc.am.AddAuthorization("none")
|
||||||
go rc.start()
|
go rc.start()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -204,15 +205,10 @@ func (rc *Connection) DoContext(ctx context.Context, do func(context.Context) er
|
||||||
func (rc *Connection) RequestOpenChannel(ctype string, handler channels.Handler) (*channels.Channel, error) {
|
func (rc *Connection) RequestOpenChannel(ctype string, handler channels.Handler) (*channels.Channel, error) {
|
||||||
rc.traceLog(fmt.Sprintf("requesting open channel of type %s", ctype))
|
rc.traceLog(fmt.Sprintf("requesting open channel of type %s", ctype))
|
||||||
|
|
||||||
// Check that we have the authentication already
|
err := rc.am.Authorized(handler.RequiresAuthentication())
|
||||||
if handler.RequiresAuthentication() != "none" {
|
if err != nil {
|
||||||
// Enforce Authentication Check.
|
return nil, err
|
||||||
_, authed := rc.Authentication[handler.RequiresAuthentication()]
|
|
||||||
if !authed {
|
|
||||||
return nil, utils.UnauthorizedActionError
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
channel, err := rc.channelManager.OpenChannelRequest(handler)
|
channel, err := rc.channelManager.OpenChannelRequest(handler)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -224,7 +220,7 @@ func (rc *Connection) RequestOpenChannel(ctype string, handler channels.Handler)
|
||||||
rc.SendRicochetPacket(rc.Conn, channel.ID, message)
|
rc.SendRicochetPacket(rc.Conn, channel.ID, message)
|
||||||
}
|
}
|
||||||
channel.DelegateAuthorization = func() {
|
channel.DelegateAuthorization = func() {
|
||||||
rc.Authentication[handler.Type()] = true
|
rc.am.AddAuthorization(handler.Type())
|
||||||
}
|
}
|
||||||
channel.CloseChannel = func() {
|
channel.CloseChannel = func() {
|
||||||
rc.SendRicochetPacket(rc.Conn, channel.ID, []byte{})
|
rc.SendRicochetPacket(rc.Conn, channel.ID, []byte{})
|
||||||
|
@ -379,58 +375,51 @@ func (rc *Connection) controlPacket(handler Handler, res *Protocol_Data_Control.
|
||||||
opm := res.GetOpenChannel()
|
opm := res.GetOpenChannel()
|
||||||
chandler, err := handler.OnOpenChannelRequest(opm.GetChannelType())
|
chandler, err := handler.OnOpenChannelRequest(opm.GetChannelType())
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
|
|
||||||
response := rc.messageBuilder.RejectOpenChannel(opm.GetChannelIdentifier(), "UnknownTypeError")
|
|
||||||
rc.SendRicochetPacket(rc.Conn, 0, response)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check that we have the authentication already
|
// Check that we have the authentication already
|
||||||
if chandler.RequiresAuthentication() != "none" {
|
|
||||||
rc.traceLog(fmt.Sprintf("channel %v requires authorization of type %v", chandler.Type(), chandler.RequiresAuthentication()))
|
|
||||||
// Enforce Authentication Check.
|
|
||||||
_, authed := rc.Authentication[chandler.RequiresAuthentication()]
|
|
||||||
if !authed {
|
|
||||||
response := rc.messageBuilder.RejectOpenChannel(opm.GetChannelIdentifier(), "UnauthorizedError")
|
|
||||||
rc.SendRicochetPacket(rc.Conn, 0, response)
|
|
||||||
rc.traceLog(fmt.Sprintf("do not have required authorization to open channel type %v", chandler.Type()))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
rc.traceLog("succeeded authorization check")
|
|
||||||
}
|
|
||||||
|
|
||||||
channel, err := rc.channelManager.OpenChannelRequestFromPeer(opm.GetChannelIdentifier(), chandler)
|
|
||||||
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
|
err = rc.am.Authorized(chandler.RequiresAuthentication())
|
||||||
|
if err == nil {
|
||||||
|
rc.traceLog("succeeded authorization check")
|
||||||
|
channel, err := rc.channelManager.OpenChannelRequestFromPeer(opm.GetChannelIdentifier(), chandler)
|
||||||
|
if err == nil {
|
||||||
|
|
||||||
channel.SendMessage = func(message []byte) {
|
channel.SendMessage = func(message []byte) {
|
||||||
rc.SendRicochetPacket(rc.Conn, channel.ID, message)
|
rc.SendRicochetPacket(rc.Conn, channel.ID, message)
|
||||||
}
|
}
|
||||||
channel.DelegateAuthorization = func() {
|
channel.DelegateAuthorization = func() {
|
||||||
rc.Authentication[chandler.Type()] = true
|
rc.am.AddAuthorization(chandler.Type())
|
||||||
}
|
}
|
||||||
channel.CloseChannel = func() {
|
channel.CloseChannel = func() {
|
||||||
rc.SendRicochetPacket(rc.Conn, channel.ID, []byte{})
|
rc.SendRicochetPacket(rc.Conn, channel.ID, []byte{})
|
||||||
rc.channelManager.RemoveChannel(channel.ID)
|
rc.channelManager.RemoveChannel(channel.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
response, err := chandler.OpenInbound(channel, opm)
|
response, err := chandler.OpenInbound(channel, opm)
|
||||||
if err == nil && channel.Pending == false {
|
if err == nil && channel.Pending == false {
|
||||||
rc.traceLog(fmt.Sprintf("opening channel %v on %v", channel.Type, channel.ID))
|
rc.traceLog(fmt.Sprintf("opening channel %v on %v", channel.Type, channel.ID))
|
||||||
rc.SendRicochetPacket(rc.Conn, 0, response)
|
rc.SendRicochetPacket(rc.Conn, 0, response)
|
||||||
} else {
|
} else {
|
||||||
rc.traceLog(fmt.Sprintf("removing channel %v", channel.ID))
|
rc.traceLog(fmt.Sprintf("removing channel %v", channel.ID))
|
||||||
rc.channelManager.RemoveChannel(channel.ID)
|
rc.channelManager.RemoveChannel(channel.ID)
|
||||||
rc.SendRicochetPacket(rc.Conn, 0, []byte{})
|
rc.SendRicochetPacket(rc.Conn, 0, []byte{})
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
// Send Error Packet
|
|
||||||
response := rc.messageBuilder.RejectOpenChannel(opm.GetChannelIdentifier(), "GenericError")
|
|
||||||
rc.traceLog(fmt.Sprintf("sending reject open channel for %v", opm.GetChannelIdentifier()))
|
|
||||||
rc.SendRicochetPacket(rc.Conn, 0, response)
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
errorText := "GenericError"
|
||||||
|
switch err {
|
||||||
|
case utils.UnknownChannelTypeError:
|
||||||
|
errorText = "UnknownTypeError"
|
||||||
|
case utils.UnauthorizedChannelTypeError:
|
||||||
|
errorText = "UnauthorizedTypeError"
|
||||||
|
}
|
||||||
|
// Send Error Packet
|
||||||
|
response := rc.messageBuilder.RejectOpenChannel(opm.GetChannelIdentifier(), errorText)
|
||||||
|
rc.traceLog(fmt.Sprintf("sending reject open channel for %v", opm.GetChannelIdentifier()))
|
||||||
|
rc.SendRicochetPacket(rc.Conn, 0, response)
|
||||||
|
|
||||||
} else if res.GetChannelResult() != nil {
|
} else if res.GetChannelResult() != nil {
|
||||||
rc.ctrlChannel.ProcessChannelResult(res.GetChannelResult())
|
rc.ctrlChannel.ProcessChannelResult(res.GetChannelResult())
|
||||||
} else if res.GetKeepAlive() != nil {
|
} else if res.GetKeepAlive() != nil {
|
||||||
|
|
Reference in New Issue