Pass channel handler directly to RequestOpenChannel

RequestOpenChannel is the primary API to open a new outbound channel. It
was written to take a connection.Handler and use OnOpenChannelRequest
to get a channels.Handler to represent the new channel, which is the
same path that inbound channels will take.

Going through the global OnOpenChannelRequest method makes this much
less flexible and prevents passing parameters to the new channel handler
during creation.  This also requires users of the API to know/find the
connection handler, or worse, to boilerplate one into existence for their
channel creation.

Instead, I think this function should take a channels.Handler directly,
so that the caller gets full control over the handler for their new
channel.

As part of that change, I've also moved the authentication logic in
AutoConnectionHandler to be contained entirely within
{In,Out}boundConnectionHandler.
This commit is contained in:
John Brooks 2017-09-16 15:46:28 +02:00 committed by Sarah Jamie Lewis
parent 1ed9265866
commit d19102b257
9 changed files with 83 additions and 132 deletions

View File

@ -10,13 +10,6 @@ const (
Outbound
)
// AuthChannelResult captures the result of an authentication flow
type AuthChannelResult struct {
Hostname string
Accepted bool
IsKnownContact bool
}
// Channel holds the state of a channel on an open connection
type Channel struct {
ID int32

View File

@ -20,28 +20,21 @@ const (
// HiddenServiceAuthChannel wraps implementation of im.ricochet.auth.hidden-service"
type HiddenServiceAuthChannel struct {
// Methods of Handler are called for events on this channel
Handler AuthChannelHandler
// PrivateKey must be set for client-side authentication channels
PrivateKey *rsa.PrivateKey
// Server Hostname must be set for client-side authentication channels
ServerHostname string
// Callbacks
ClientAuthResult func(accepted, isKnownContact bool)
ServerAuthValid func(hostname string, publicKey rsa.PublicKey) (allowed, known bool)
ServerAuthInvalid func(err error)
// Internal state
clientCookie, serverCookie [16]byte
channel *Channel
}
// AuthChannelHandler ...
type AuthChannelHandler interface {
// Client
ClientAuthResult(accepted bool, isKnownContact bool)
// Server
ServerAuthValid(hostname string, publicKey rsa.PublicKey) (allowed, known bool)
ServerAuthInvalid(err error)
}
// Type returns the type string for this channel, e.g. "im.ricochet.chat".
func (ah *HiddenServiceAuthChannel) Type() string {
return "im.ricochet.auth.hidden-service"
@ -173,7 +166,7 @@ func (ah *HiddenServiceAuthChannel) Packet(data []byte) {
})
if err != nil {
ah.Handler.ServerAuthInvalid(err)
ah.ServerAuthInvalid(err)
ah.channel.SendMessage([]byte{})
return
}
@ -183,7 +176,7 @@ func (ah *HiddenServiceAuthChannel) Packet(data []byte) {
publicKey := rsa.PublicKey{}
_, err = asn1.Unmarshal(res.GetProof().GetPublicKey(), &publicKey)
if err != nil {
ah.Handler.ServerAuthInvalid(err)
ah.ServerAuthInvalid(err)
ah.channel.SendMessage([]byte{})
return
}
@ -194,7 +187,7 @@ func (ah *HiddenServiceAuthChannel) Packet(data []byte) {
if err == nil {
// Signature is Good
accepted, isKnownContact := ah.Handler.ServerAuthValid(provisionalClientHostname, publicKey)
accepted, isKnownContact := ah.ServerAuthValid(provisionalClientHostname, publicKey)
// Send Result
messageBuilder := new(utils.MessageBuilder)
@ -206,11 +199,13 @@ func (ah *HiddenServiceAuthChannel) Packet(data []byte) {
messageBuilder := new(utils.MessageBuilder)
result := messageBuilder.AuthResult(false, false)
ah.channel.SendMessage(result)
ah.Handler.ServerAuthInvalid(err)
ah.ServerAuthInvalid(err)
}
} else if res.GetResult() != nil && ah.channel.Direction == Outbound {
ah.Handler.ClientAuthResult(res.GetResult().GetAccepted(), res.GetResult().GetIsKnownContact())
if ah.ClientAuthResult != nil {
ah.ClientAuthResult(res.GetResult().GetAccepted(), res.GetResult().GetIsKnownContact())
}
if res.GetResult().GetAccepted() {
ah.channel.DelegateAuthorization()
}

View File

@ -110,34 +110,16 @@ func TestAuthenticationOpenOutbound(t *testing.T) {
}
type SimpleTestAuthHandler struct {
}
// Client
func (stah *SimpleTestAuthHandler) ClientAuthResult(accepted bool, isKnownContact bool) {
}
// Server
func (stah *SimpleTestAuthHandler) ServerAuthValid(hostname string, publicKey rsa.PublicKey) (allowed, known bool) {
return true, true
}
func (stah *SimpleTestAuthHandler) ServerAuthInvalid(err error) {
}
func TestAuthenticationOpenOutboundResult(t *testing.T) {
privateKey, _ := utils.LoadPrivateKeyFromFile("../testing/private_key")
authHandlerA := new(HiddenServiceAuthChannel)
authHandlerB := new(HiddenServiceAuthChannel)
simpleTestAuthHandler := new(SimpleTestAuthHandler)
authHandlerA.ServerHostname = "kwke2hntvyfqm7dr"
authHandlerA.PrivateKey = privateKey
authHandlerA.Handler = simpleTestAuthHandler
authHandlerA.ClientAuthResult = func(accepted, known bool) {}
channelA := Channel{ID: 1, Direction: Outbound}
channelA.SendMessage = func(message []byte) {
authHandlerB.Packet(message)
@ -150,7 +132,8 @@ func TestAuthenticationOpenOutboundResult(t *testing.T) {
authHandlerB.ServerHostname = "kwke2hntvyfqm7dr"
authHandlerB.PrivateKey = privateKey
authHandlerB.Handler = simpleTestAuthHandler
authHandlerB.ServerAuthValid = func(hostname string, publicKey rsa.PublicKey) (allowed, known bool) { return true, true }
authHandlerB.ServerAuthInvalid = func(err error) { t.Error("server received invalid auth") }
channelB := Channel{ID: 1, Direction: Inbound}
channelB.SendMessage = func(message []byte) {
authHandlerA.Packet(message)

View File

@ -1,7 +1,6 @@
package connection
import (
"crypto/rsa"
"github.com/s-rah/go-ricochet/channels"
"github.com/s-rah/go-ricochet/utils"
)
@ -16,30 +15,14 @@ import (
// AutoConnectionHandler's behavior (such as adding new channel types, or reacting
// to connection close events), this type can be embedded in the type that it serves.
type AutoConnectionHandler struct {
handlerMap map[string]func() channels.Handler
connection *Connection
authResultChannel chan channels.AuthChannelResult
sach func(hostname string, publicKey rsa.PublicKey) (allowed, known bool)
handlerMap map[string]func() channels.Handler
connection *Connection
}
// Init ...
// TODO: Split this into client and server init
func (ach *AutoConnectionHandler) Init(privateKey *rsa.PrivateKey, serverHostname string) {
func (ach *AutoConnectionHandler) Init() {
ach.handlerMap = make(map[string]func() channels.Handler)
ach.RegisterChannelHandler("im.ricochet.auth.hidden-service", func() channels.Handler {
hsau := new(channels.HiddenServiceAuthChannel)
hsau.PrivateKey = privateKey
hsau.Handler = ach
hsau.ServerHostname = serverHostname
return hsau
})
ach.authResultChannel = make(chan channels.AuthChannelResult)
}
// SetServerAuthHandler ...
func (ach *AutoConnectionHandler) SetServerAuthHandler(sach func(hostname string, publicKey rsa.PublicKey) (allowed, known bool)) {
ach.sach = sach
}
// OnReady ...
@ -51,29 +34,6 @@ func (ach *AutoConnectionHandler) OnReady(oc *Connection) {
func (ach *AutoConnectionHandler) OnClosed(err error) {
}
// WaitForAuthenticationEvent ...
func (ach *AutoConnectionHandler) WaitForAuthenticationEvent() channels.AuthChannelResult {
return <-ach.authResultChannel
}
// ClientAuthResult ...
func (ach *AutoConnectionHandler) ClientAuthResult(accepted bool, isKnownContact bool) {
ach.authResultChannel <- channels.AuthChannelResult{Accepted: accepted, IsKnownContact: isKnownContact}
}
// ServerAuthValid ...
func (ach *AutoConnectionHandler) ServerAuthValid(hostname string, publicKey rsa.PublicKey) (allowed, known bool) {
// Do something
accepted, isKnownContact := ach.sach(hostname, publicKey)
ach.authResultChannel <- channels.AuthChannelResult{Hostname: hostname, Accepted: accepted, IsKnownContact: isKnownContact}
return accepted, isKnownContact
}
// ServerAuthInvalid ...
func (ach *AutoConnectionHandler) ServerAuthInvalid(err error) {
ach.authResultChannel <- channels.AuthChannelResult{Accepted: false, IsKnownContact: false}
}
// RegisterChannelHandler ...
func (ach *AutoConnectionHandler) RegisterChannelHandler(ctype string, handler func() channels.Handler) {
_, exists := ach.handlerMap[ctype]

View File

@ -2,6 +2,7 @@ package connection
import (
"github.com/golang/protobuf/proto"
"github.com/s-rah/go-ricochet/channels"
"github.com/s-rah/go-ricochet/utils"
"github.com/s-rah/go-ricochet/wire/control"
"testing"
@ -10,9 +11,10 @@ import (
// Test sending valid packets
func TestInit(t *testing.T) {
ach := new(AutoConnectionHandler)
privateKey, err := utils.LoadPrivateKeyFromFile("../testing/private_key")
ach.Init(privateKey, "")
ach.Init()
ach.RegisterChannelHandler("im.ricochet.auth.hidden-service", func() channels.Handler {
return &channels.HiddenServiceAuthChannel{}
})
// Construct the Open Authentication Channel Message
messageBuilder := new(utils.MessageBuilder)

View File

@ -106,32 +106,25 @@ func (rc *Connection) Do(do func() error) error {
}
// RequestOpenChannel sends an OpenChannel message to the remote client.
// and error is returned only if the requirements for opening this channel
// are not met on the local side (a nill error return does not mean the
// channel was opened successfully)
func (rc *Connection) RequestOpenChannel(ctype string, handler Handler) error {
// An error is returned only if the requirements for opening this channel
// are not met on the local side (a nil error return does not mean the
// channel was opened successfully, because channels open asynchronously).
func (rc *Connection) RequestOpenChannel(ctype string, handler channels.Handler) error {
rc.traceLog(fmt.Sprintf("requesting open channel of type %s", ctype))
return rc.Do(func() error {
chandler, err := handler.OnOpenChannelRequest(ctype)
if err != nil {
rc.traceLog(fmt.Sprintf("failed to request open channel of type %v", err))
return err
}
// Check that we have the authentication already
if chandler.RequiresAuthentication() != "none" {
if handler.RequiresAuthentication() != "none" {
// Enforce Authentication Check.
_, authed := rc.Authentication[chandler.RequiresAuthentication()]
_, authed := rc.Authentication[handler.RequiresAuthentication()]
if !authed {
return utils.UnauthorizedActionError
}
}
channel, err := rc.channelManager.OpenChannelRequest(chandler)
channel, err := rc.channelManager.OpenChannelRequest(handler)
if err != nil {
rc.traceLog(fmt.Sprintf("failed to reqeust open channel of type %v", err))
rc.traceLog(fmt.Sprintf("failed to request open channel of type %v", err))
return err
}
@ -139,18 +132,18 @@ func (rc *Connection) RequestOpenChannel(ctype string, handler Handler) error {
rc.SendRicochetPacket(rc.Conn, channel.ID, message)
}
channel.DelegateAuthorization = func() {
rc.Authentication[chandler.Type()] = true
rc.Authentication[handler.Type()] = true
}
channel.CloseChannel = func() {
rc.SendRicochetPacket(rc.Conn, channel.ID, []byte{})
rc.channelManager.RemoveChannel(channel.ID)
}
response, err := chandler.OpenOutbound(channel)
response, err := handler.OpenOutbound(channel)
if err == nil {
rc.traceLog(fmt.Sprintf("requested open channel of type %s", ctype))
rc.SendRicochetPacket(rc.Conn, 0, response)
} else {
rc.traceLog(fmt.Sprintf("failed to reqeust open channel of type %v", err))
rc.traceLog(fmt.Sprintf("failed to request open channel of type %v", err))
rc.channelManager.RemoveChannel(channel.ID)
}
return nil

View File

@ -39,15 +39,32 @@ func (ich *InboundConnectionHandler) ProcessAuthAsServer(privateKey *rsa.Private
return utils.PrivateKeyNotSetError
}
ach := new(AutoConnectionHandler)
ach.Init(privateKey, ich.connection.RemoteHostname)
ach.SetServerAuthHandler(sach)
var authAllowed, authKnown bool
var authHostname string
var authResult channels.AuthChannelResult
go func() {
authResult = ach.WaitForAuthenticationEvent()
onAuthValid := func(hostname string, publicKey rsa.PublicKey) (allowed, known bool) {
authAllowed, authKnown = sach(hostname, publicKey)
if authAllowed {
authHostname = hostname
}
ich.connection.Break()
}()
return authAllowed, authKnown
}
onAuthInvalid := func(err error) {
// err is ignored at the moment
ich.connection.Break()
}
ach := new(AutoConnectionHandler)
ach.Init()
ach.RegisterChannelHandler("im.ricochet.auth.hidden-service",
func() channels.Handler {
return &channels.HiddenServiceAuthChannel{
PrivateKey: privateKey,
ServerAuthValid: onAuthValid,
ServerAuthInvalid: onAuthInvalid,
}
})
policy := policies.UnknownPurposeTimeout
err := policy.ExecuteAction(func() error {
@ -55,8 +72,8 @@ func (ich *InboundConnectionHandler) ProcessAuthAsServer(privateKey *rsa.Private
})
if err == nil {
if authResult.Accepted == true {
ich.connection.RemoteHostname = authResult.Hostname
if authAllowed == true {
ich.connection.RemoteHostname = authHostname
return nil
}
return utils.ClientFailedToAuthenticateError

View File

@ -38,26 +38,34 @@ func (och *OutboundConnectionHandler) ProcessAuthAsClient(privateKey *rsa.Privat
}
ach := new(AutoConnectionHandler)
ach.Init(privateKey, och.connection.RemoteHostname)
ach.Init()
var result channels.AuthChannelResult
go func() {
err := och.connection.RequestOpenChannel("im.ricochet.auth.hidden-service", ach)
if err != nil {
return
}
result = ach.WaitForAuthenticationEvent()
var accepted, isKnownContact bool
authCallback := func(accept, known bool) {
accepted = accept
isKnownContact = known
// Cause the Process() call below to return
och.connection.Break()
}()
}
err := och.connection.RequestOpenChannel("im.ricochet.auth.hidden-service",
&channels.HiddenServiceAuthChannel{
PrivateKey: privateKey,
ServerHostname: och.connection.RemoteHostname,
ClientAuthResult: authCallback,
})
if err != nil {
return false, err
}
policy := policies.UnknownPurposeTimeout
err := policy.ExecuteAction(func() error {
err = policy.ExecuteAction(func() error {
return och.connection.Process(ach)
})
if err == nil {
if result.Accepted == true {
return result.IsKnownContact, nil
if accepted == true {
return isKnownContact, nil
}
}
return false, utils.ServerRejectedClientConnectionError

View File

@ -45,7 +45,7 @@ func (echobot *RicochetEchoBot) Connect(privateKeyFile string, hostname string)
privateKey, _ := utils.LoadPrivateKeyFromFile(privateKeyFile)
echobot.messages = make(chan string)
echobot.Init(privateKey, hostname)
echobot.Init()
echobot.RegisterChannelHandler("im.ricochet.contact.request", func() channels.Handler {
contact := new(channels.ContactRequestChannel)
contact.Handler = echobot
@ -64,13 +64,13 @@ func (echobot *RicochetEchoBot) Connect(privateKeyFile string, hostname string)
go rc.Process(echobot)
if !known {
err := rc.RequestOpenChannel("im.ricochet.contact.request", echobot)
err := rc.RequestOpenChannel("im.ricochet.contact.request", &channels.ContactRequestChannel{Handler: echobot})
if err != nil {
log.Printf("could not contact %s", err)
}
}
rc.RequestOpenChannel("im.ricochet.chat", echobot)
rc.RequestOpenChannel("im.ricochet.chat", &channels.ChatChannel{Handler: echobot})
for {
message := <-echobot.messages
log.Printf("Received Message: %s", message)