forked from openprivacy/libricochet-go
Extract Authorization Logic / Refactoring OpenChannel Control Logic
This commit is contained in:
parent
4994e54025
commit
1e33c17ae3
|
@ -0,0 +1,28 @@
|
|||
package connection
|
||||
|
||||
import (
|
||||
"github.com/s-rah/go-ricochet/utils"
|
||||
)
|
||||
|
||||
type AuthorizationManager struct {
|
||||
Authorizations map[string]bool
|
||||
}
|
||||
|
||||
// Init sets up an AuthorizationManager to be used.
|
||||
func (am *AuthorizationManager) Init() {
|
||||
am.Authorizations = make(map[string]bool)
|
||||
}
|
||||
|
||||
// AddAuthorization adds the string authz to the map of allowed authorizations
|
||||
func (am *AuthorizationManager) AddAuthorization(authz string) {
|
||||
am.Authorizations[authz] = true
|
||||
}
|
||||
|
||||
// Authorized returns no error in the case an authz type is authorized, error otherwise.
|
||||
func (am *AuthorizationManager) Authorized(authz string) error {
|
||||
_, authed := am.Authorizations[authz]
|
||||
if !authed {
|
||||
return utils.UnauthorizedActionError
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,18 @@
|
|||
package connection
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestAuthorizationManager(t *testing.T) {
|
||||
am := new(AuthorizationManager)
|
||||
am.Init()
|
||||
am.AddAuthorization("test")
|
||||
if am.Authorized("test") != nil {
|
||||
t.Errorf("Authorized(test) should return nil, instead returned error: %v", am.Authorized("test"))
|
||||
}
|
||||
|
||||
if am.Authorized("not_authed") == nil {
|
||||
t.Errorf("Authorized(not_authed) should return error, instead returned nil: %v", am.Authorized("not_authed"))
|
||||
}
|
||||
}
|
|
@ -42,7 +42,7 @@ type Connection struct {
|
|||
|
||||
Conn io.ReadWriteCloser
|
||||
IsInbound bool
|
||||
Authentication map[string]bool
|
||||
am AuthorizationManager
|
||||
RemoteHostname string
|
||||
SupportChannels []string
|
||||
}
|
||||
|
@ -58,7 +58,8 @@ func (rc *Connection) init() {
|
|||
rc.unlockChannel = make(chan bool)
|
||||
rc.unlockResponseChannel = make(chan bool)
|
||||
|
||||
rc.Authentication = make(map[string]bool)
|
||||
rc.am.Init()
|
||||
rc.am.AddAuthorization("none")
|
||||
go rc.start()
|
||||
}
|
||||
|
||||
|
@ -204,15 +205,10 @@ func (rc *Connection) DoContext(ctx context.Context, do func(context.Context) er
|
|||
func (rc *Connection) RequestOpenChannel(ctype string, handler channels.Handler) (*channels.Channel, error) {
|
||||
rc.traceLog(fmt.Sprintf("requesting open channel of type %s", ctype))
|
||||
|
||||
// Check that we have the authentication already
|
||||
if handler.RequiresAuthentication() != "none" {
|
||||
// Enforce Authentication Check.
|
||||
_, authed := rc.Authentication[handler.RequiresAuthentication()]
|
||||
if !authed {
|
||||
return nil, utils.UnauthorizedActionError
|
||||
}
|
||||
err := rc.am.Authorized(handler.RequiresAuthentication())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
channel, err := rc.channelManager.OpenChannelRequest(handler)
|
||||
|
||||
if err != nil {
|
||||
|
@ -224,7 +220,7 @@ func (rc *Connection) RequestOpenChannel(ctype string, handler channels.Handler)
|
|||
rc.SendRicochetPacket(rc.Conn, channel.ID, message)
|
||||
}
|
||||
channel.DelegateAuthorization = func() {
|
||||
rc.Authentication[handler.Type()] = true
|
||||
rc.am.AddAuthorization(handler.Type())
|
||||
}
|
||||
channel.CloseChannel = func() {
|
||||
rc.SendRicochetPacket(rc.Conn, channel.ID, []byte{})
|
||||
|
@ -379,58 +375,51 @@ func (rc *Connection) controlPacket(handler Handler, res *Protocol_Data_Control.
|
|||
opm := res.GetOpenChannel()
|
||||
chandler, err := handler.OnOpenChannelRequest(opm.GetChannelType())
|
||||
|
||||
if err != nil {
|
||||
|
||||
response := rc.messageBuilder.RejectOpenChannel(opm.GetChannelIdentifier(), "UnknownTypeError")
|
||||
rc.SendRicochetPacket(rc.Conn, 0, response)
|
||||
return
|
||||
}
|
||||
|
||||
// Check that we have the authentication already
|
||||
if chandler.RequiresAuthentication() != "none" {
|
||||
rc.traceLog(fmt.Sprintf("channel %v requires authorization of type %v", chandler.Type(), chandler.RequiresAuthentication()))
|
||||
// Enforce Authentication Check.
|
||||
_, authed := rc.Authentication[chandler.RequiresAuthentication()]
|
||||
if !authed {
|
||||
response := rc.messageBuilder.RejectOpenChannel(opm.GetChannelIdentifier(), "UnauthorizedError")
|
||||
rc.SendRicochetPacket(rc.Conn, 0, response)
|
||||
rc.traceLog(fmt.Sprintf("do not have required authorization to open channel type %v", chandler.Type()))
|
||||
return
|
||||
}
|
||||
rc.traceLog("succeeded authorization check")
|
||||
}
|
||||
|
||||
channel, err := rc.channelManager.OpenChannelRequestFromPeer(opm.GetChannelIdentifier(), chandler)
|
||||
|
||||
if err == nil {
|
||||
err = rc.am.Authorized(chandler.RequiresAuthentication())
|
||||
if err == nil {
|
||||
rc.traceLog("succeeded authorization check")
|
||||
channel, err := rc.channelManager.OpenChannelRequestFromPeer(opm.GetChannelIdentifier(), chandler)
|
||||
if err == nil {
|
||||
|
||||
channel.SendMessage = func(message []byte) {
|
||||
rc.SendRicochetPacket(rc.Conn, channel.ID, message)
|
||||
}
|
||||
channel.DelegateAuthorization = func() {
|
||||
rc.Authentication[chandler.Type()] = true
|
||||
}
|
||||
channel.CloseChannel = func() {
|
||||
rc.SendRicochetPacket(rc.Conn, channel.ID, []byte{})
|
||||
rc.channelManager.RemoveChannel(channel.ID)
|
||||
}
|
||||
channel.SendMessage = func(message []byte) {
|
||||
rc.SendRicochetPacket(rc.Conn, channel.ID, message)
|
||||
}
|
||||
channel.DelegateAuthorization = func() {
|
||||
rc.am.AddAuthorization(chandler.Type())
|
||||
}
|
||||
channel.CloseChannel = func() {
|
||||
rc.SendRicochetPacket(rc.Conn, channel.ID, []byte{})
|
||||
rc.channelManager.RemoveChannel(channel.ID)
|
||||
}
|
||||
|
||||
response, err := chandler.OpenInbound(channel, opm)
|
||||
if err == nil && channel.Pending == false {
|
||||
rc.traceLog(fmt.Sprintf("opening channel %v on %v", channel.Type, channel.ID))
|
||||
rc.SendRicochetPacket(rc.Conn, 0, response)
|
||||
} else {
|
||||
rc.traceLog(fmt.Sprintf("removing channel %v", channel.ID))
|
||||
rc.channelManager.RemoveChannel(channel.ID)
|
||||
rc.SendRicochetPacket(rc.Conn, 0, []byte{})
|
||||
response, err := chandler.OpenInbound(channel, opm)
|
||||
if err == nil && channel.Pending == false {
|
||||
rc.traceLog(fmt.Sprintf("opening channel %v on %v", channel.Type, channel.ID))
|
||||
rc.SendRicochetPacket(rc.Conn, 0, response)
|
||||
} else {
|
||||
rc.traceLog(fmt.Sprintf("removing channel %v", channel.ID))
|
||||
rc.channelManager.RemoveChannel(channel.ID)
|
||||
rc.SendRicochetPacket(rc.Conn, 0, []byte{})
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Send Error Packet
|
||||
response := rc.messageBuilder.RejectOpenChannel(opm.GetChannelIdentifier(), "GenericError")
|
||||
rc.traceLog(fmt.Sprintf("sending reject open channel for %v", opm.GetChannelIdentifier()))
|
||||
rc.SendRicochetPacket(rc.Conn, 0, response)
|
||||
|
||||
}
|
||||
|
||||
errorText := "GenericError"
|
||||
switch err {
|
||||
case utils.UnknownChannelTypeError:
|
||||
errorText = "UnknownTypeError"
|
||||
case utils.UnauthorizedChannelTypeError:
|
||||
errorText = "UnauthorizedTypeError"
|
||||
}
|
||||
// Send Error Packet
|
||||
response := rc.messageBuilder.RejectOpenChannel(opm.GetChannelIdentifier(), errorText)
|
||||
rc.traceLog(fmt.Sprintf("sending reject open channel for %v", opm.GetChannelIdentifier()))
|
||||
rc.SendRicochetPacket(rc.Conn, 0, response)
|
||||
|
||||
} else if res.GetChannelResult() != nil {
|
||||
rc.ctrlChannel.ProcessChannelResult(res.GetChannelResult())
|
||||
} else if res.GetKeepAlive() != nil {
|
||||
|
|
Loading…
Reference in New Issue