Pass channel handler directly to RequestOpenChannel

RequestOpenChannel is the primary API to open a new outbound channel. It
was written to take a connection.Handler and use OnOpenChannelRequest
to get a channels.Handler to represent the new channel, which is the
same path that inbound channels will take.

Going through the global OnOpenChannelRequest method makes this much
less flexible and prevents passing parameters to the new channel handler
during creation.  This also requires users of the API to know/find the
connection handler, or worse, to boilerplate one into existence for their
channel creation.

Instead, I think this function should take a channels.Handler directly,
so that the caller gets full control over the handler for their new
channel.

As part of that change, I've also moved the authentication logic in
AutoConnectionHandler to be contained entirely within
{In,Out}boundConnectionHandler.
This commit is contained in:
John Brooks 2017-09-16 15:46:28 +02:00 committed by Sarah Jamie Lewis
parent 1ed9265866
commit d19102b257
9 changed files with 83 additions and 132 deletions

View File

@ -10,13 +10,6 @@ const (
Outbound Outbound
) )
// AuthChannelResult captures the result of an authentication flow
type AuthChannelResult struct {
Hostname string
Accepted bool
IsKnownContact bool
}
// Channel holds the state of a channel on an open connection // Channel holds the state of a channel on an open connection
type Channel struct { type Channel struct {
ID int32 ID int32

View File

@ -20,28 +20,21 @@ const (
// HiddenServiceAuthChannel wraps implementation of im.ricochet.auth.hidden-service" // HiddenServiceAuthChannel wraps implementation of im.ricochet.auth.hidden-service"
type HiddenServiceAuthChannel struct { type HiddenServiceAuthChannel struct {
// Methods of Handler are called for events on this channel
Handler AuthChannelHandler
// PrivateKey must be set for client-side authentication channels // PrivateKey must be set for client-side authentication channels
PrivateKey *rsa.PrivateKey PrivateKey *rsa.PrivateKey
// Server Hostname must be set for client-side authentication channels // Server Hostname must be set for client-side authentication channels
ServerHostname string ServerHostname string
// Callbacks
ClientAuthResult func(accepted, isKnownContact bool)
ServerAuthValid func(hostname string, publicKey rsa.PublicKey) (allowed, known bool)
ServerAuthInvalid func(err error)
// Internal state // Internal state
clientCookie, serverCookie [16]byte clientCookie, serverCookie [16]byte
channel *Channel channel *Channel
} }
// AuthChannelHandler ...
type AuthChannelHandler interface {
// Client
ClientAuthResult(accepted bool, isKnownContact bool)
// Server
ServerAuthValid(hostname string, publicKey rsa.PublicKey) (allowed, known bool)
ServerAuthInvalid(err error)
}
// Type returns the type string for this channel, e.g. "im.ricochet.chat". // Type returns the type string for this channel, e.g. "im.ricochet.chat".
func (ah *HiddenServiceAuthChannel) Type() string { func (ah *HiddenServiceAuthChannel) Type() string {
return "im.ricochet.auth.hidden-service" return "im.ricochet.auth.hidden-service"
@ -173,7 +166,7 @@ func (ah *HiddenServiceAuthChannel) Packet(data []byte) {
}) })
if err != nil { if err != nil {
ah.Handler.ServerAuthInvalid(err) ah.ServerAuthInvalid(err)
ah.channel.SendMessage([]byte{}) ah.channel.SendMessage([]byte{})
return return
} }
@ -183,7 +176,7 @@ func (ah *HiddenServiceAuthChannel) Packet(data []byte) {
publicKey := rsa.PublicKey{} publicKey := rsa.PublicKey{}
_, err = asn1.Unmarshal(res.GetProof().GetPublicKey(), &publicKey) _, err = asn1.Unmarshal(res.GetProof().GetPublicKey(), &publicKey)
if err != nil { if err != nil {
ah.Handler.ServerAuthInvalid(err) ah.ServerAuthInvalid(err)
ah.channel.SendMessage([]byte{}) ah.channel.SendMessage([]byte{})
return return
} }
@ -194,7 +187,7 @@ func (ah *HiddenServiceAuthChannel) Packet(data []byte) {
if err == nil { if err == nil {
// Signature is Good // Signature is Good
accepted, isKnownContact := ah.Handler.ServerAuthValid(provisionalClientHostname, publicKey) accepted, isKnownContact := ah.ServerAuthValid(provisionalClientHostname, publicKey)
// Send Result // Send Result
messageBuilder := new(utils.MessageBuilder) messageBuilder := new(utils.MessageBuilder)
@ -206,11 +199,13 @@ func (ah *HiddenServiceAuthChannel) Packet(data []byte) {
messageBuilder := new(utils.MessageBuilder) messageBuilder := new(utils.MessageBuilder)
result := messageBuilder.AuthResult(false, false) result := messageBuilder.AuthResult(false, false)
ah.channel.SendMessage(result) ah.channel.SendMessage(result)
ah.Handler.ServerAuthInvalid(err) ah.ServerAuthInvalid(err)
} }
} else if res.GetResult() != nil && ah.channel.Direction == Outbound { } else if res.GetResult() != nil && ah.channel.Direction == Outbound {
ah.Handler.ClientAuthResult(res.GetResult().GetAccepted(), res.GetResult().GetIsKnownContact()) if ah.ClientAuthResult != nil {
ah.ClientAuthResult(res.GetResult().GetAccepted(), res.GetResult().GetIsKnownContact())
}
if res.GetResult().GetAccepted() { if res.GetResult().GetAccepted() {
ah.channel.DelegateAuthorization() ah.channel.DelegateAuthorization()
} }

View File

@ -110,34 +110,16 @@ func TestAuthenticationOpenOutbound(t *testing.T) {
} }
type SimpleTestAuthHandler struct {
}
// Client
func (stah *SimpleTestAuthHandler) ClientAuthResult(accepted bool, isKnownContact bool) {
}
// Server
func (stah *SimpleTestAuthHandler) ServerAuthValid(hostname string, publicKey rsa.PublicKey) (allowed, known bool) {
return true, true
}
func (stah *SimpleTestAuthHandler) ServerAuthInvalid(err error) {
}
func TestAuthenticationOpenOutboundResult(t *testing.T) { func TestAuthenticationOpenOutboundResult(t *testing.T) {
privateKey, _ := utils.LoadPrivateKeyFromFile("../testing/private_key") privateKey, _ := utils.LoadPrivateKeyFromFile("../testing/private_key")
authHandlerA := new(HiddenServiceAuthChannel) authHandlerA := new(HiddenServiceAuthChannel)
authHandlerB := new(HiddenServiceAuthChannel) authHandlerB := new(HiddenServiceAuthChannel)
simpleTestAuthHandler := new(SimpleTestAuthHandler)
authHandlerA.ServerHostname = "kwke2hntvyfqm7dr" authHandlerA.ServerHostname = "kwke2hntvyfqm7dr"
authHandlerA.PrivateKey = privateKey authHandlerA.PrivateKey = privateKey
authHandlerA.Handler = simpleTestAuthHandler authHandlerA.ClientAuthResult = func(accepted, known bool) {}
channelA := Channel{ID: 1, Direction: Outbound} channelA := Channel{ID: 1, Direction: Outbound}
channelA.SendMessage = func(message []byte) { channelA.SendMessage = func(message []byte) {
authHandlerB.Packet(message) authHandlerB.Packet(message)
@ -150,7 +132,8 @@ func TestAuthenticationOpenOutboundResult(t *testing.T) {
authHandlerB.ServerHostname = "kwke2hntvyfqm7dr" authHandlerB.ServerHostname = "kwke2hntvyfqm7dr"
authHandlerB.PrivateKey = privateKey authHandlerB.PrivateKey = privateKey
authHandlerB.Handler = simpleTestAuthHandler authHandlerB.ServerAuthValid = func(hostname string, publicKey rsa.PublicKey) (allowed, known bool) { return true, true }
authHandlerB.ServerAuthInvalid = func(err error) { t.Error("server received invalid auth") }
channelB := Channel{ID: 1, Direction: Inbound} channelB := Channel{ID: 1, Direction: Inbound}
channelB.SendMessage = func(message []byte) { channelB.SendMessage = func(message []byte) {
authHandlerA.Packet(message) authHandlerA.Packet(message)

View File

@ -1,7 +1,6 @@
package connection package connection
import ( import (
"crypto/rsa"
"github.com/s-rah/go-ricochet/channels" "github.com/s-rah/go-ricochet/channels"
"github.com/s-rah/go-ricochet/utils" "github.com/s-rah/go-ricochet/utils"
) )
@ -16,30 +15,14 @@ import (
// AutoConnectionHandler's behavior (such as adding new channel types, or reacting // AutoConnectionHandler's behavior (such as adding new channel types, or reacting
// to connection close events), this type can be embedded in the type that it serves. // to connection close events), this type can be embedded in the type that it serves.
type AutoConnectionHandler struct { type AutoConnectionHandler struct {
handlerMap map[string]func() channels.Handler handlerMap map[string]func() channels.Handler
connection *Connection connection *Connection
authResultChannel chan channels.AuthChannelResult
sach func(hostname string, publicKey rsa.PublicKey) (allowed, known bool)
} }
// Init ... // Init ...
// TODO: Split this into client and server init // TODO: Split this into client and server init
func (ach *AutoConnectionHandler) Init(privateKey *rsa.PrivateKey, serverHostname string) { func (ach *AutoConnectionHandler) Init() {
ach.handlerMap = make(map[string]func() channels.Handler) ach.handlerMap = make(map[string]func() channels.Handler)
ach.RegisterChannelHandler("im.ricochet.auth.hidden-service", func() channels.Handler {
hsau := new(channels.HiddenServiceAuthChannel)
hsau.PrivateKey = privateKey
hsau.Handler = ach
hsau.ServerHostname = serverHostname
return hsau
})
ach.authResultChannel = make(chan channels.AuthChannelResult)
}
// SetServerAuthHandler ...
func (ach *AutoConnectionHandler) SetServerAuthHandler(sach func(hostname string, publicKey rsa.PublicKey) (allowed, known bool)) {
ach.sach = sach
} }
// OnReady ... // OnReady ...
@ -51,29 +34,6 @@ func (ach *AutoConnectionHandler) OnReady(oc *Connection) {
func (ach *AutoConnectionHandler) OnClosed(err error) { func (ach *AutoConnectionHandler) OnClosed(err error) {
} }
// WaitForAuthenticationEvent ...
func (ach *AutoConnectionHandler) WaitForAuthenticationEvent() channels.AuthChannelResult {
return <-ach.authResultChannel
}
// ClientAuthResult ...
func (ach *AutoConnectionHandler) ClientAuthResult(accepted bool, isKnownContact bool) {
ach.authResultChannel <- channels.AuthChannelResult{Accepted: accepted, IsKnownContact: isKnownContact}
}
// ServerAuthValid ...
func (ach *AutoConnectionHandler) ServerAuthValid(hostname string, publicKey rsa.PublicKey) (allowed, known bool) {
// Do something
accepted, isKnownContact := ach.sach(hostname, publicKey)
ach.authResultChannel <- channels.AuthChannelResult{Hostname: hostname, Accepted: accepted, IsKnownContact: isKnownContact}
return accepted, isKnownContact
}
// ServerAuthInvalid ...
func (ach *AutoConnectionHandler) ServerAuthInvalid(err error) {
ach.authResultChannel <- channels.AuthChannelResult{Accepted: false, IsKnownContact: false}
}
// RegisterChannelHandler ... // RegisterChannelHandler ...
func (ach *AutoConnectionHandler) RegisterChannelHandler(ctype string, handler func() channels.Handler) { func (ach *AutoConnectionHandler) RegisterChannelHandler(ctype string, handler func() channels.Handler) {
_, exists := ach.handlerMap[ctype] _, exists := ach.handlerMap[ctype]

View File

@ -2,6 +2,7 @@ package connection
import ( import (
"github.com/golang/protobuf/proto" "github.com/golang/protobuf/proto"
"github.com/s-rah/go-ricochet/channels"
"github.com/s-rah/go-ricochet/utils" "github.com/s-rah/go-ricochet/utils"
"github.com/s-rah/go-ricochet/wire/control" "github.com/s-rah/go-ricochet/wire/control"
"testing" "testing"
@ -10,9 +11,10 @@ import (
// Test sending valid packets // Test sending valid packets
func TestInit(t *testing.T) { func TestInit(t *testing.T) {
ach := new(AutoConnectionHandler) ach := new(AutoConnectionHandler)
privateKey, err := utils.LoadPrivateKeyFromFile("../testing/private_key") ach.Init()
ach.RegisterChannelHandler("im.ricochet.auth.hidden-service", func() channels.Handler {
ach.Init(privateKey, "") return &channels.HiddenServiceAuthChannel{}
})
// Construct the Open Authentication Channel Message // Construct the Open Authentication Channel Message
messageBuilder := new(utils.MessageBuilder) messageBuilder := new(utils.MessageBuilder)

View File

@ -106,32 +106,25 @@ func (rc *Connection) Do(do func() error) error {
} }
// RequestOpenChannel sends an OpenChannel message to the remote client. // RequestOpenChannel sends an OpenChannel message to the remote client.
// and error is returned only if the requirements for opening this channel // An error is returned only if the requirements for opening this channel
// are not met on the local side (a nill error return does not mean the // are not met on the local side (a nil error return does not mean the
// channel was opened successfully) // channel was opened successfully, because channels open asynchronously).
func (rc *Connection) RequestOpenChannel(ctype string, handler Handler) error { func (rc *Connection) RequestOpenChannel(ctype string, handler channels.Handler) error {
rc.traceLog(fmt.Sprintf("requesting open channel of type %s", ctype)) rc.traceLog(fmt.Sprintf("requesting open channel of type %s", ctype))
return rc.Do(func() error { return rc.Do(func() error {
chandler, err := handler.OnOpenChannelRequest(ctype)
if err != nil {
rc.traceLog(fmt.Sprintf("failed to request open channel of type %v", err))
return err
}
// Check that we have the authentication already // Check that we have the authentication already
if chandler.RequiresAuthentication() != "none" { if handler.RequiresAuthentication() != "none" {
// Enforce Authentication Check. // Enforce Authentication Check.
_, authed := rc.Authentication[chandler.RequiresAuthentication()] _, authed := rc.Authentication[handler.RequiresAuthentication()]
if !authed { if !authed {
return utils.UnauthorizedActionError return utils.UnauthorizedActionError
} }
} }
channel, err := rc.channelManager.OpenChannelRequest(chandler) channel, err := rc.channelManager.OpenChannelRequest(handler)
if err != nil { if err != nil {
rc.traceLog(fmt.Sprintf("failed to reqeust open channel of type %v", err)) rc.traceLog(fmt.Sprintf("failed to request open channel of type %v", err))
return err return err
} }
@ -139,18 +132,18 @@ func (rc *Connection) RequestOpenChannel(ctype string, handler Handler) error {
rc.SendRicochetPacket(rc.Conn, channel.ID, message) rc.SendRicochetPacket(rc.Conn, channel.ID, message)
} }
channel.DelegateAuthorization = func() { channel.DelegateAuthorization = func() {
rc.Authentication[chandler.Type()] = true rc.Authentication[handler.Type()] = true
} }
channel.CloseChannel = func() { channel.CloseChannel = func() {
rc.SendRicochetPacket(rc.Conn, channel.ID, []byte{}) rc.SendRicochetPacket(rc.Conn, channel.ID, []byte{})
rc.channelManager.RemoveChannel(channel.ID) rc.channelManager.RemoveChannel(channel.ID)
} }
response, err := chandler.OpenOutbound(channel) response, err := handler.OpenOutbound(channel)
if err == nil { if err == nil {
rc.traceLog(fmt.Sprintf("requested open channel of type %s", ctype)) rc.traceLog(fmt.Sprintf("requested open channel of type %s", ctype))
rc.SendRicochetPacket(rc.Conn, 0, response) rc.SendRicochetPacket(rc.Conn, 0, response)
} else { } else {
rc.traceLog(fmt.Sprintf("failed to reqeust open channel of type %v", err)) rc.traceLog(fmt.Sprintf("failed to request open channel of type %v", err))
rc.channelManager.RemoveChannel(channel.ID) rc.channelManager.RemoveChannel(channel.ID)
} }
return nil return nil

View File

@ -39,15 +39,32 @@ func (ich *InboundConnectionHandler) ProcessAuthAsServer(privateKey *rsa.Private
return utils.PrivateKeyNotSetError return utils.PrivateKeyNotSetError
} }
ach := new(AutoConnectionHandler) var authAllowed, authKnown bool
ach.Init(privateKey, ich.connection.RemoteHostname) var authHostname string
ach.SetServerAuthHandler(sach)
var authResult channels.AuthChannelResult onAuthValid := func(hostname string, publicKey rsa.PublicKey) (allowed, known bool) {
go func() { authAllowed, authKnown = sach(hostname, publicKey)
authResult = ach.WaitForAuthenticationEvent() if authAllowed {
authHostname = hostname
}
ich.connection.Break() ich.connection.Break()
}() return authAllowed, authKnown
}
onAuthInvalid := func(err error) {
// err is ignored at the moment
ich.connection.Break()
}
ach := new(AutoConnectionHandler)
ach.Init()
ach.RegisterChannelHandler("im.ricochet.auth.hidden-service",
func() channels.Handler {
return &channels.HiddenServiceAuthChannel{
PrivateKey: privateKey,
ServerAuthValid: onAuthValid,
ServerAuthInvalid: onAuthInvalid,
}
})
policy := policies.UnknownPurposeTimeout policy := policies.UnknownPurposeTimeout
err := policy.ExecuteAction(func() error { err := policy.ExecuteAction(func() error {
@ -55,8 +72,8 @@ func (ich *InboundConnectionHandler) ProcessAuthAsServer(privateKey *rsa.Private
}) })
if err == nil { if err == nil {
if authResult.Accepted == true { if authAllowed == true {
ich.connection.RemoteHostname = authResult.Hostname ich.connection.RemoteHostname = authHostname
return nil return nil
} }
return utils.ClientFailedToAuthenticateError return utils.ClientFailedToAuthenticateError

View File

@ -38,26 +38,34 @@ func (och *OutboundConnectionHandler) ProcessAuthAsClient(privateKey *rsa.Privat
} }
ach := new(AutoConnectionHandler) ach := new(AutoConnectionHandler)
ach.Init(privateKey, och.connection.RemoteHostname) ach.Init()
var result channels.AuthChannelResult var accepted, isKnownContact bool
go func() { authCallback := func(accept, known bool) {
err := och.connection.RequestOpenChannel("im.ricochet.auth.hidden-service", ach) accepted = accept
if err != nil { isKnownContact = known
return // Cause the Process() call below to return
}
result = ach.WaitForAuthenticationEvent()
och.connection.Break() och.connection.Break()
}() }
err := och.connection.RequestOpenChannel("im.ricochet.auth.hidden-service",
&channels.HiddenServiceAuthChannel{
PrivateKey: privateKey,
ServerHostname: och.connection.RemoteHostname,
ClientAuthResult: authCallback,
})
if err != nil {
return false, err
}
policy := policies.UnknownPurposeTimeout policy := policies.UnknownPurposeTimeout
err := policy.ExecuteAction(func() error { err = policy.ExecuteAction(func() error {
return och.connection.Process(ach) return och.connection.Process(ach)
}) })
if err == nil { if err == nil {
if result.Accepted == true { if accepted == true {
return result.IsKnownContact, nil return isKnownContact, nil
} }
} }
return false, utils.ServerRejectedClientConnectionError return false, utils.ServerRejectedClientConnectionError

View File

@ -45,7 +45,7 @@ func (echobot *RicochetEchoBot) Connect(privateKeyFile string, hostname string)
privateKey, _ := utils.LoadPrivateKeyFromFile(privateKeyFile) privateKey, _ := utils.LoadPrivateKeyFromFile(privateKeyFile)
echobot.messages = make(chan string) echobot.messages = make(chan string)
echobot.Init(privateKey, hostname) echobot.Init()
echobot.RegisterChannelHandler("im.ricochet.contact.request", func() channels.Handler { echobot.RegisterChannelHandler("im.ricochet.contact.request", func() channels.Handler {
contact := new(channels.ContactRequestChannel) contact := new(channels.ContactRequestChannel)
contact.Handler = echobot contact.Handler = echobot
@ -64,13 +64,13 @@ func (echobot *RicochetEchoBot) Connect(privateKeyFile string, hostname string)
go rc.Process(echobot) go rc.Process(echobot)
if !known { if !known {
err := rc.RequestOpenChannel("im.ricochet.contact.request", echobot) err := rc.RequestOpenChannel("im.ricochet.contact.request", &channels.ContactRequestChannel{Handler: echobot})
if err != nil { if err != nil {
log.Printf("could not contact %s", err) log.Printf("could not contact %s", err)
} }
} }
rc.RequestOpenChannel("im.ricochet.chat", echobot) rc.RequestOpenChannel("im.ricochet.chat", &channels.ChatChannel{Handler: echobot})
for { for {
message := <-echobot.messages message := <-echobot.messages
log.Printf("Received Message: %s", message) log.Printf("Received Message: %s", message)