Pass channel handler directly to RequestOpenChannel
RequestOpenChannel is the primary API to open a new outbound channel. It was written to take a connection.Handler and use OnOpenChannelRequest to get a channels.Handler to represent the new channel, which is the same path that inbound channels will take. Going through the global OnOpenChannelRequest method makes this much less flexible and prevents passing parameters to the new channel handler during creation. This also requires users of the API to know/find the connection handler, or worse, to boilerplate one into existence for their channel creation. Instead, I think this function should take a channels.Handler directly, so that the caller gets full control over the handler for their new channel. As part of that change, I've also moved the authentication logic in AutoConnectionHandler to be contained entirely within {In,Out}boundConnectionHandler.
This commit is contained in:
parent
1ed9265866
commit
d19102b257
|
@ -10,13 +10,6 @@ const (
|
|||
Outbound
|
||||
)
|
||||
|
||||
// AuthChannelResult captures the result of an authentication flow
|
||||
type AuthChannelResult struct {
|
||||
Hostname string
|
||||
Accepted bool
|
||||
IsKnownContact bool
|
||||
}
|
||||
|
||||
// Channel holds the state of a channel on an open connection
|
||||
type Channel struct {
|
||||
ID int32
|
||||
|
|
|
@ -20,28 +20,21 @@ const (
|
|||
|
||||
// HiddenServiceAuthChannel wraps implementation of im.ricochet.auth.hidden-service"
|
||||
type HiddenServiceAuthChannel struct {
|
||||
// Methods of Handler are called for events on this channel
|
||||
Handler AuthChannelHandler
|
||||
// PrivateKey must be set for client-side authentication channels
|
||||
PrivateKey *rsa.PrivateKey
|
||||
// Server Hostname must be set for client-side authentication channels
|
||||
ServerHostname string
|
||||
|
||||
// Callbacks
|
||||
ClientAuthResult func(accepted, isKnownContact bool)
|
||||
ServerAuthValid func(hostname string, publicKey rsa.PublicKey) (allowed, known bool)
|
||||
ServerAuthInvalid func(err error)
|
||||
|
||||
// Internal state
|
||||
clientCookie, serverCookie [16]byte
|
||||
channel *Channel
|
||||
}
|
||||
|
||||
// AuthChannelHandler ...
|
||||
type AuthChannelHandler interface {
|
||||
// Client
|
||||
ClientAuthResult(accepted bool, isKnownContact bool)
|
||||
|
||||
// Server
|
||||
ServerAuthValid(hostname string, publicKey rsa.PublicKey) (allowed, known bool)
|
||||
ServerAuthInvalid(err error)
|
||||
}
|
||||
|
||||
// Type returns the type string for this channel, e.g. "im.ricochet.chat".
|
||||
func (ah *HiddenServiceAuthChannel) Type() string {
|
||||
return "im.ricochet.auth.hidden-service"
|
||||
|
@ -173,7 +166,7 @@ func (ah *HiddenServiceAuthChannel) Packet(data []byte) {
|
|||
})
|
||||
|
||||
if err != nil {
|
||||
ah.Handler.ServerAuthInvalid(err)
|
||||
ah.ServerAuthInvalid(err)
|
||||
ah.channel.SendMessage([]byte{})
|
||||
return
|
||||
}
|
||||
|
@ -183,7 +176,7 @@ func (ah *HiddenServiceAuthChannel) Packet(data []byte) {
|
|||
publicKey := rsa.PublicKey{}
|
||||
_, err = asn1.Unmarshal(res.GetProof().GetPublicKey(), &publicKey)
|
||||
if err != nil {
|
||||
ah.Handler.ServerAuthInvalid(err)
|
||||
ah.ServerAuthInvalid(err)
|
||||
ah.channel.SendMessage([]byte{})
|
||||
return
|
||||
}
|
||||
|
@ -194,7 +187,7 @@ func (ah *HiddenServiceAuthChannel) Packet(data []byte) {
|
|||
|
||||
if err == nil {
|
||||
// Signature is Good
|
||||
accepted, isKnownContact := ah.Handler.ServerAuthValid(provisionalClientHostname, publicKey)
|
||||
accepted, isKnownContact := ah.ServerAuthValid(provisionalClientHostname, publicKey)
|
||||
|
||||
// Send Result
|
||||
messageBuilder := new(utils.MessageBuilder)
|
||||
|
@ -206,11 +199,13 @@ func (ah *HiddenServiceAuthChannel) Packet(data []byte) {
|
|||
messageBuilder := new(utils.MessageBuilder)
|
||||
result := messageBuilder.AuthResult(false, false)
|
||||
ah.channel.SendMessage(result)
|
||||
ah.Handler.ServerAuthInvalid(err)
|
||||
ah.ServerAuthInvalid(err)
|
||||
}
|
||||
|
||||
} else if res.GetResult() != nil && ah.channel.Direction == Outbound {
|
||||
ah.Handler.ClientAuthResult(res.GetResult().GetAccepted(), res.GetResult().GetIsKnownContact())
|
||||
if ah.ClientAuthResult != nil {
|
||||
ah.ClientAuthResult(res.GetResult().GetAccepted(), res.GetResult().GetIsKnownContact())
|
||||
}
|
||||
if res.GetResult().GetAccepted() {
|
||||
ah.channel.DelegateAuthorization()
|
||||
}
|
||||
|
|
|
@ -110,34 +110,16 @@ func TestAuthenticationOpenOutbound(t *testing.T) {
|
|||
|
||||
}
|
||||
|
||||
type SimpleTestAuthHandler struct {
|
||||
}
|
||||
|
||||
// Client
|
||||
func (stah *SimpleTestAuthHandler) ClientAuthResult(accepted bool, isKnownContact bool) {
|
||||
|
||||
}
|
||||
|
||||
// Server
|
||||
func (stah *SimpleTestAuthHandler) ServerAuthValid(hostname string, publicKey rsa.PublicKey) (allowed, known bool) {
|
||||
return true, true
|
||||
}
|
||||
|
||||
func (stah *SimpleTestAuthHandler) ServerAuthInvalid(err error) {
|
||||
|
||||
}
|
||||
|
||||
func TestAuthenticationOpenOutboundResult(t *testing.T) {
|
||||
|
||||
privateKey, _ := utils.LoadPrivateKeyFromFile("../testing/private_key")
|
||||
|
||||
authHandlerA := new(HiddenServiceAuthChannel)
|
||||
authHandlerB := new(HiddenServiceAuthChannel)
|
||||
simpleTestAuthHandler := new(SimpleTestAuthHandler)
|
||||
|
||||
authHandlerA.ServerHostname = "kwke2hntvyfqm7dr"
|
||||
authHandlerA.PrivateKey = privateKey
|
||||
authHandlerA.Handler = simpleTestAuthHandler
|
||||
authHandlerA.ClientAuthResult = func(accepted, known bool) {}
|
||||
channelA := Channel{ID: 1, Direction: Outbound}
|
||||
channelA.SendMessage = func(message []byte) {
|
||||
authHandlerB.Packet(message)
|
||||
|
@ -150,7 +132,8 @@ func TestAuthenticationOpenOutboundResult(t *testing.T) {
|
|||
|
||||
authHandlerB.ServerHostname = "kwke2hntvyfqm7dr"
|
||||
authHandlerB.PrivateKey = privateKey
|
||||
authHandlerB.Handler = simpleTestAuthHandler
|
||||
authHandlerB.ServerAuthValid = func(hostname string, publicKey rsa.PublicKey) (allowed, known bool) { return true, true }
|
||||
authHandlerB.ServerAuthInvalid = func(err error) { t.Error("server received invalid auth") }
|
||||
channelB := Channel{ID: 1, Direction: Inbound}
|
||||
channelB.SendMessage = func(message []byte) {
|
||||
authHandlerA.Packet(message)
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package connection
|
||||
|
||||
import (
|
||||
"crypto/rsa"
|
||||
"github.com/s-rah/go-ricochet/channels"
|
||||
"github.com/s-rah/go-ricochet/utils"
|
||||
)
|
||||
|
@ -16,30 +15,14 @@ import (
|
|||
// AutoConnectionHandler's behavior (such as adding new channel types, or reacting
|
||||
// to connection close events), this type can be embedded in the type that it serves.
|
||||
type AutoConnectionHandler struct {
|
||||
handlerMap map[string]func() channels.Handler
|
||||
connection *Connection
|
||||
authResultChannel chan channels.AuthChannelResult
|
||||
sach func(hostname string, publicKey rsa.PublicKey) (allowed, known bool)
|
||||
handlerMap map[string]func() channels.Handler
|
||||
connection *Connection
|
||||
}
|
||||
|
||||
// Init ...
|
||||
// TODO: Split this into client and server init
|
||||
func (ach *AutoConnectionHandler) Init(privateKey *rsa.PrivateKey, serverHostname string) {
|
||||
|
||||
func (ach *AutoConnectionHandler) Init() {
|
||||
ach.handlerMap = make(map[string]func() channels.Handler)
|
||||
ach.RegisterChannelHandler("im.ricochet.auth.hidden-service", func() channels.Handler {
|
||||
hsau := new(channels.HiddenServiceAuthChannel)
|
||||
hsau.PrivateKey = privateKey
|
||||
hsau.Handler = ach
|
||||
hsau.ServerHostname = serverHostname
|
||||
return hsau
|
||||
})
|
||||
ach.authResultChannel = make(chan channels.AuthChannelResult)
|
||||
}
|
||||
|
||||
// SetServerAuthHandler ...
|
||||
func (ach *AutoConnectionHandler) SetServerAuthHandler(sach func(hostname string, publicKey rsa.PublicKey) (allowed, known bool)) {
|
||||
ach.sach = sach
|
||||
}
|
||||
|
||||
// OnReady ...
|
||||
|
@ -51,29 +34,6 @@ func (ach *AutoConnectionHandler) OnReady(oc *Connection) {
|
|||
func (ach *AutoConnectionHandler) OnClosed(err error) {
|
||||
}
|
||||
|
||||
// WaitForAuthenticationEvent ...
|
||||
func (ach *AutoConnectionHandler) WaitForAuthenticationEvent() channels.AuthChannelResult {
|
||||
return <-ach.authResultChannel
|
||||
}
|
||||
|
||||
// ClientAuthResult ...
|
||||
func (ach *AutoConnectionHandler) ClientAuthResult(accepted bool, isKnownContact bool) {
|
||||
ach.authResultChannel <- channels.AuthChannelResult{Accepted: accepted, IsKnownContact: isKnownContact}
|
||||
}
|
||||
|
||||
// ServerAuthValid ...
|
||||
func (ach *AutoConnectionHandler) ServerAuthValid(hostname string, publicKey rsa.PublicKey) (allowed, known bool) {
|
||||
// Do something
|
||||
accepted, isKnownContact := ach.sach(hostname, publicKey)
|
||||
ach.authResultChannel <- channels.AuthChannelResult{Hostname: hostname, Accepted: accepted, IsKnownContact: isKnownContact}
|
||||
return accepted, isKnownContact
|
||||
}
|
||||
|
||||
// ServerAuthInvalid ...
|
||||
func (ach *AutoConnectionHandler) ServerAuthInvalid(err error) {
|
||||
ach.authResultChannel <- channels.AuthChannelResult{Accepted: false, IsKnownContact: false}
|
||||
}
|
||||
|
||||
// RegisterChannelHandler ...
|
||||
func (ach *AutoConnectionHandler) RegisterChannelHandler(ctype string, handler func() channels.Handler) {
|
||||
_, exists := ach.handlerMap[ctype]
|
||||
|
|
|
@ -2,6 +2,7 @@ package connection
|
|||
|
||||
import (
|
||||
"github.com/golang/protobuf/proto"
|
||||
"github.com/s-rah/go-ricochet/channels"
|
||||
"github.com/s-rah/go-ricochet/utils"
|
||||
"github.com/s-rah/go-ricochet/wire/control"
|
||||
"testing"
|
||||
|
@ -10,9 +11,10 @@ import (
|
|||
// Test sending valid packets
|
||||
func TestInit(t *testing.T) {
|
||||
ach := new(AutoConnectionHandler)
|
||||
privateKey, err := utils.LoadPrivateKeyFromFile("../testing/private_key")
|
||||
|
||||
ach.Init(privateKey, "")
|
||||
ach.Init()
|
||||
ach.RegisterChannelHandler("im.ricochet.auth.hidden-service", func() channels.Handler {
|
||||
return &channels.HiddenServiceAuthChannel{}
|
||||
})
|
||||
|
||||
// Construct the Open Authentication Channel Message
|
||||
messageBuilder := new(utils.MessageBuilder)
|
||||
|
|
|
@ -106,32 +106,25 @@ func (rc *Connection) Do(do func() error) error {
|
|||
}
|
||||
|
||||
// RequestOpenChannel sends an OpenChannel message to the remote client.
|
||||
// and error is returned only if the requirements for opening this channel
|
||||
// are not met on the local side (a nill error return does not mean the
|
||||
// channel was opened successfully)
|
||||
func (rc *Connection) RequestOpenChannel(ctype string, handler Handler) error {
|
||||
// An error is returned only if the requirements for opening this channel
|
||||
// are not met on the local side (a nil error return does not mean the
|
||||
// channel was opened successfully, because channels open asynchronously).
|
||||
func (rc *Connection) RequestOpenChannel(ctype string, handler channels.Handler) error {
|
||||
rc.traceLog(fmt.Sprintf("requesting open channel of type %s", ctype))
|
||||
return rc.Do(func() error {
|
||||
chandler, err := handler.OnOpenChannelRequest(ctype)
|
||||
|
||||
if err != nil {
|
||||
rc.traceLog(fmt.Sprintf("failed to request open channel of type %v", err))
|
||||
return err
|
||||
}
|
||||
|
||||
// Check that we have the authentication already
|
||||
if chandler.RequiresAuthentication() != "none" {
|
||||
if handler.RequiresAuthentication() != "none" {
|
||||
// Enforce Authentication Check.
|
||||
_, authed := rc.Authentication[chandler.RequiresAuthentication()]
|
||||
_, authed := rc.Authentication[handler.RequiresAuthentication()]
|
||||
if !authed {
|
||||
return utils.UnauthorizedActionError
|
||||
}
|
||||
}
|
||||
|
||||
channel, err := rc.channelManager.OpenChannelRequest(chandler)
|
||||
channel, err := rc.channelManager.OpenChannelRequest(handler)
|
||||
|
||||
if err != nil {
|
||||
rc.traceLog(fmt.Sprintf("failed to reqeust open channel of type %v", err))
|
||||
rc.traceLog(fmt.Sprintf("failed to request open channel of type %v", err))
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -139,18 +132,18 @@ func (rc *Connection) RequestOpenChannel(ctype string, handler Handler) error {
|
|||
rc.SendRicochetPacket(rc.Conn, channel.ID, message)
|
||||
}
|
||||
channel.DelegateAuthorization = func() {
|
||||
rc.Authentication[chandler.Type()] = true
|
||||
rc.Authentication[handler.Type()] = true
|
||||
}
|
||||
channel.CloseChannel = func() {
|
||||
rc.SendRicochetPacket(rc.Conn, channel.ID, []byte{})
|
||||
rc.channelManager.RemoveChannel(channel.ID)
|
||||
}
|
||||
response, err := chandler.OpenOutbound(channel)
|
||||
response, err := handler.OpenOutbound(channel)
|
||||
if err == nil {
|
||||
rc.traceLog(fmt.Sprintf("requested open channel of type %s", ctype))
|
||||
rc.SendRicochetPacket(rc.Conn, 0, response)
|
||||
} else {
|
||||
rc.traceLog(fmt.Sprintf("failed to reqeust open channel of type %v", err))
|
||||
rc.traceLog(fmt.Sprintf("failed to request open channel of type %v", err))
|
||||
rc.channelManager.RemoveChannel(channel.ID)
|
||||
}
|
||||
return nil
|
||||
|
|
|
@ -39,15 +39,32 @@ func (ich *InboundConnectionHandler) ProcessAuthAsServer(privateKey *rsa.Private
|
|||
return utils.PrivateKeyNotSetError
|
||||
}
|
||||
|
||||
ach := new(AutoConnectionHandler)
|
||||
ach.Init(privateKey, ich.connection.RemoteHostname)
|
||||
ach.SetServerAuthHandler(sach)
|
||||
var authAllowed, authKnown bool
|
||||
var authHostname string
|
||||
|
||||
var authResult channels.AuthChannelResult
|
||||
go func() {
|
||||
authResult = ach.WaitForAuthenticationEvent()
|
||||
onAuthValid := func(hostname string, publicKey rsa.PublicKey) (allowed, known bool) {
|
||||
authAllowed, authKnown = sach(hostname, publicKey)
|
||||
if authAllowed {
|
||||
authHostname = hostname
|
||||
}
|
||||
ich.connection.Break()
|
||||
}()
|
||||
return authAllowed, authKnown
|
||||
}
|
||||
onAuthInvalid := func(err error) {
|
||||
// err is ignored at the moment
|
||||
ich.connection.Break()
|
||||
}
|
||||
|
||||
ach := new(AutoConnectionHandler)
|
||||
ach.Init()
|
||||
ach.RegisterChannelHandler("im.ricochet.auth.hidden-service",
|
||||
func() channels.Handler {
|
||||
return &channels.HiddenServiceAuthChannel{
|
||||
PrivateKey: privateKey,
|
||||
ServerAuthValid: onAuthValid,
|
||||
ServerAuthInvalid: onAuthInvalid,
|
||||
}
|
||||
})
|
||||
|
||||
policy := policies.UnknownPurposeTimeout
|
||||
err := policy.ExecuteAction(func() error {
|
||||
|
@ -55,8 +72,8 @@ func (ich *InboundConnectionHandler) ProcessAuthAsServer(privateKey *rsa.Private
|
|||
})
|
||||
|
||||
if err == nil {
|
||||
if authResult.Accepted == true {
|
||||
ich.connection.RemoteHostname = authResult.Hostname
|
||||
if authAllowed == true {
|
||||
ich.connection.RemoteHostname = authHostname
|
||||
return nil
|
||||
}
|
||||
return utils.ClientFailedToAuthenticateError
|
||||
|
|
|
@ -38,26 +38,34 @@ func (och *OutboundConnectionHandler) ProcessAuthAsClient(privateKey *rsa.Privat
|
|||
}
|
||||
|
||||
ach := new(AutoConnectionHandler)
|
||||
ach.Init(privateKey, och.connection.RemoteHostname)
|
||||
ach.Init()
|
||||
|
||||
var result channels.AuthChannelResult
|
||||
go func() {
|
||||
err := och.connection.RequestOpenChannel("im.ricochet.auth.hidden-service", ach)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
result = ach.WaitForAuthenticationEvent()
|
||||
var accepted, isKnownContact bool
|
||||
authCallback := func(accept, known bool) {
|
||||
accepted = accept
|
||||
isKnownContact = known
|
||||
// Cause the Process() call below to return
|
||||
och.connection.Break()
|
||||
}()
|
||||
}
|
||||
|
||||
err := och.connection.RequestOpenChannel("im.ricochet.auth.hidden-service",
|
||||
&channels.HiddenServiceAuthChannel{
|
||||
PrivateKey: privateKey,
|
||||
ServerHostname: och.connection.RemoteHostname,
|
||||
ClientAuthResult: authCallback,
|
||||
})
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
policy := policies.UnknownPurposeTimeout
|
||||
err := policy.ExecuteAction(func() error {
|
||||
err = policy.ExecuteAction(func() error {
|
||||
return och.connection.Process(ach)
|
||||
})
|
||||
|
||||
if err == nil {
|
||||
if result.Accepted == true {
|
||||
return result.IsKnownContact, nil
|
||||
if accepted == true {
|
||||
return isKnownContact, nil
|
||||
}
|
||||
}
|
||||
return false, utils.ServerRejectedClientConnectionError
|
||||
|
|
|
@ -45,7 +45,7 @@ func (echobot *RicochetEchoBot) Connect(privateKeyFile string, hostname string)
|
|||
privateKey, _ := utils.LoadPrivateKeyFromFile(privateKeyFile)
|
||||
echobot.messages = make(chan string)
|
||||
|
||||
echobot.Init(privateKey, hostname)
|
||||
echobot.Init()
|
||||
echobot.RegisterChannelHandler("im.ricochet.contact.request", func() channels.Handler {
|
||||
contact := new(channels.ContactRequestChannel)
|
||||
contact.Handler = echobot
|
||||
|
@ -64,13 +64,13 @@ func (echobot *RicochetEchoBot) Connect(privateKeyFile string, hostname string)
|
|||
go rc.Process(echobot)
|
||||
|
||||
if !known {
|
||||
err := rc.RequestOpenChannel("im.ricochet.contact.request", echobot)
|
||||
err := rc.RequestOpenChannel("im.ricochet.contact.request", &channels.ContactRequestChannel{Handler: echobot})
|
||||
if err != nil {
|
||||
log.Printf("could not contact %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
rc.RequestOpenChannel("im.ricochet.chat", echobot)
|
||||
rc.RequestOpenChannel("im.ricochet.chat", &channels.ChatChannel{Handler: echobot})
|
||||
for {
|
||||
message := <-echobot.messages
|
||||
log.Printf("Received Message: %s", message)
|
||||
|
|
Reference in New Issue