From d19102b2572b30a05c333e4d95f86b6e49396855 Mon Sep 17 00:00:00 2001 From: John Brooks Date: Sat, 16 Sep 2017 15:46:28 +0200 Subject: [PATCH] Pass channel handler directly to RequestOpenChannel RequestOpenChannel is the primary API to open a new outbound channel. It was written to take a connection.Handler and use OnOpenChannelRequest to get a channels.Handler to represent the new channel, which is the same path that inbound channels will take. Going through the global OnOpenChannelRequest method makes this much less flexible and prevents passing parameters to the new channel handler during creation. This also requires users of the API to know/find the connection handler, or worse, to boilerplate one into existence for their channel creation. Instead, I think this function should take a channels.Handler directly, so that the caller gets full control over the handler for their new channel. As part of that change, I've also moved the authentication logic in AutoConnectionHandler to be contained entirely within {In,Out}boundConnectionHandler. --- channels/channel.go | 7 ---- channels/hiddenserviceauthchannel.go | 29 ++++++-------- channels/hiddenserviceauthchannel_test.go | 23 ++---------- connection/autoconnectionhandler.go | 46 ++--------------------- connection/autoconnectionhandler_test.go | 8 ++-- connection/connection.go | 29 ++++++-------- connection/inboundconnectionhandler.go | 35 ++++++++++++----- connection/outboundconnectionhandler.go | 32 ++++++++++------ examples/echobot/main.go | 6 +-- 9 files changed, 83 insertions(+), 132 deletions(-) diff --git a/channels/channel.go b/channels/channel.go index 86c1e08..53b8632 100644 --- a/channels/channel.go +++ b/channels/channel.go @@ -10,13 +10,6 @@ const ( Outbound ) -// AuthChannelResult captures the result of an authentication flow -type AuthChannelResult struct { - Hostname string - Accepted bool - IsKnownContact bool -} - // Channel holds the state of a channel on an open connection type Channel struct { ID int32 diff --git a/channels/hiddenserviceauthchannel.go b/channels/hiddenserviceauthchannel.go index 54866db..cb90f4c 100644 --- a/channels/hiddenserviceauthchannel.go +++ b/channels/hiddenserviceauthchannel.go @@ -20,28 +20,21 @@ const ( // HiddenServiceAuthChannel wraps implementation of im.ricochet.auth.hidden-service" type HiddenServiceAuthChannel struct { - // Methods of Handler are called for events on this channel - Handler AuthChannelHandler // PrivateKey must be set for client-side authentication channels PrivateKey *rsa.PrivateKey // Server Hostname must be set for client-side authentication channels ServerHostname string + // Callbacks + ClientAuthResult func(accepted, isKnownContact bool) + ServerAuthValid func(hostname string, publicKey rsa.PublicKey) (allowed, known bool) + ServerAuthInvalid func(err error) + // Internal state clientCookie, serverCookie [16]byte channel *Channel } -// AuthChannelHandler ... -type AuthChannelHandler interface { - // Client - ClientAuthResult(accepted bool, isKnownContact bool) - - // Server - ServerAuthValid(hostname string, publicKey rsa.PublicKey) (allowed, known bool) - ServerAuthInvalid(err error) -} - // Type returns the type string for this channel, e.g. "im.ricochet.chat". func (ah *HiddenServiceAuthChannel) Type() string { return "im.ricochet.auth.hidden-service" @@ -173,7 +166,7 @@ func (ah *HiddenServiceAuthChannel) Packet(data []byte) { }) if err != nil { - ah.Handler.ServerAuthInvalid(err) + ah.ServerAuthInvalid(err) ah.channel.SendMessage([]byte{}) return } @@ -183,7 +176,7 @@ func (ah *HiddenServiceAuthChannel) Packet(data []byte) { publicKey := rsa.PublicKey{} _, err = asn1.Unmarshal(res.GetProof().GetPublicKey(), &publicKey) if err != nil { - ah.Handler.ServerAuthInvalid(err) + ah.ServerAuthInvalid(err) ah.channel.SendMessage([]byte{}) return } @@ -194,7 +187,7 @@ func (ah *HiddenServiceAuthChannel) Packet(data []byte) { if err == nil { // Signature is Good - accepted, isKnownContact := ah.Handler.ServerAuthValid(provisionalClientHostname, publicKey) + accepted, isKnownContact := ah.ServerAuthValid(provisionalClientHostname, publicKey) // Send Result messageBuilder := new(utils.MessageBuilder) @@ -206,11 +199,13 @@ func (ah *HiddenServiceAuthChannel) Packet(data []byte) { messageBuilder := new(utils.MessageBuilder) result := messageBuilder.AuthResult(false, false) ah.channel.SendMessage(result) - ah.Handler.ServerAuthInvalid(err) + ah.ServerAuthInvalid(err) } } else if res.GetResult() != nil && ah.channel.Direction == Outbound { - ah.Handler.ClientAuthResult(res.GetResult().GetAccepted(), res.GetResult().GetIsKnownContact()) + if ah.ClientAuthResult != nil { + ah.ClientAuthResult(res.GetResult().GetAccepted(), res.GetResult().GetIsKnownContact()) + } if res.GetResult().GetAccepted() { ah.channel.DelegateAuthorization() } diff --git a/channels/hiddenserviceauthchannel_test.go b/channels/hiddenserviceauthchannel_test.go index 8f3736f..e219100 100644 --- a/channels/hiddenserviceauthchannel_test.go +++ b/channels/hiddenserviceauthchannel_test.go @@ -110,34 +110,16 @@ func TestAuthenticationOpenOutbound(t *testing.T) { } -type SimpleTestAuthHandler struct { -} - -// Client -func (stah *SimpleTestAuthHandler) ClientAuthResult(accepted bool, isKnownContact bool) { - -} - -// Server -func (stah *SimpleTestAuthHandler) ServerAuthValid(hostname string, publicKey rsa.PublicKey) (allowed, known bool) { - return true, true -} - -func (stah *SimpleTestAuthHandler) ServerAuthInvalid(err error) { - -} - func TestAuthenticationOpenOutboundResult(t *testing.T) { privateKey, _ := utils.LoadPrivateKeyFromFile("../testing/private_key") authHandlerA := new(HiddenServiceAuthChannel) authHandlerB := new(HiddenServiceAuthChannel) - simpleTestAuthHandler := new(SimpleTestAuthHandler) authHandlerA.ServerHostname = "kwke2hntvyfqm7dr" authHandlerA.PrivateKey = privateKey - authHandlerA.Handler = simpleTestAuthHandler + authHandlerA.ClientAuthResult = func(accepted, known bool) {} channelA := Channel{ID: 1, Direction: Outbound} channelA.SendMessage = func(message []byte) { authHandlerB.Packet(message) @@ -150,7 +132,8 @@ func TestAuthenticationOpenOutboundResult(t *testing.T) { authHandlerB.ServerHostname = "kwke2hntvyfqm7dr" authHandlerB.PrivateKey = privateKey - authHandlerB.Handler = simpleTestAuthHandler + authHandlerB.ServerAuthValid = func(hostname string, publicKey rsa.PublicKey) (allowed, known bool) { return true, true } + authHandlerB.ServerAuthInvalid = func(err error) { t.Error("server received invalid auth") } channelB := Channel{ID: 1, Direction: Inbound} channelB.SendMessage = func(message []byte) { authHandlerA.Packet(message) diff --git a/connection/autoconnectionhandler.go b/connection/autoconnectionhandler.go index 29aad54..cbb96c3 100644 --- a/connection/autoconnectionhandler.go +++ b/connection/autoconnectionhandler.go @@ -1,7 +1,6 @@ package connection import ( - "crypto/rsa" "github.com/s-rah/go-ricochet/channels" "github.com/s-rah/go-ricochet/utils" ) @@ -16,30 +15,14 @@ import ( // AutoConnectionHandler's behavior (such as adding new channel types, or reacting // to connection close events), this type can be embedded in the type that it serves. type AutoConnectionHandler struct { - handlerMap map[string]func() channels.Handler - connection *Connection - authResultChannel chan channels.AuthChannelResult - sach func(hostname string, publicKey rsa.PublicKey) (allowed, known bool) + handlerMap map[string]func() channels.Handler + connection *Connection } // Init ... // TODO: Split this into client and server init -func (ach *AutoConnectionHandler) Init(privateKey *rsa.PrivateKey, serverHostname string) { - +func (ach *AutoConnectionHandler) Init() { ach.handlerMap = make(map[string]func() channels.Handler) - ach.RegisterChannelHandler("im.ricochet.auth.hidden-service", func() channels.Handler { - hsau := new(channels.HiddenServiceAuthChannel) - hsau.PrivateKey = privateKey - hsau.Handler = ach - hsau.ServerHostname = serverHostname - return hsau - }) - ach.authResultChannel = make(chan channels.AuthChannelResult) -} - -// SetServerAuthHandler ... -func (ach *AutoConnectionHandler) SetServerAuthHandler(sach func(hostname string, publicKey rsa.PublicKey) (allowed, known bool)) { - ach.sach = sach } // OnReady ... @@ -51,29 +34,6 @@ func (ach *AutoConnectionHandler) OnReady(oc *Connection) { func (ach *AutoConnectionHandler) OnClosed(err error) { } -// WaitForAuthenticationEvent ... -func (ach *AutoConnectionHandler) WaitForAuthenticationEvent() channels.AuthChannelResult { - return <-ach.authResultChannel -} - -// ClientAuthResult ... -func (ach *AutoConnectionHandler) ClientAuthResult(accepted bool, isKnownContact bool) { - ach.authResultChannel <- channels.AuthChannelResult{Accepted: accepted, IsKnownContact: isKnownContact} -} - -// ServerAuthValid ... -func (ach *AutoConnectionHandler) ServerAuthValid(hostname string, publicKey rsa.PublicKey) (allowed, known bool) { - // Do something - accepted, isKnownContact := ach.sach(hostname, publicKey) - ach.authResultChannel <- channels.AuthChannelResult{Hostname: hostname, Accepted: accepted, IsKnownContact: isKnownContact} - return accepted, isKnownContact -} - -// ServerAuthInvalid ... -func (ach *AutoConnectionHandler) ServerAuthInvalid(err error) { - ach.authResultChannel <- channels.AuthChannelResult{Accepted: false, IsKnownContact: false} -} - // RegisterChannelHandler ... func (ach *AutoConnectionHandler) RegisterChannelHandler(ctype string, handler func() channels.Handler) { _, exists := ach.handlerMap[ctype] diff --git a/connection/autoconnectionhandler_test.go b/connection/autoconnectionhandler_test.go index 977a022..572c2b5 100644 --- a/connection/autoconnectionhandler_test.go +++ b/connection/autoconnectionhandler_test.go @@ -2,6 +2,7 @@ package connection import ( "github.com/golang/protobuf/proto" + "github.com/s-rah/go-ricochet/channels" "github.com/s-rah/go-ricochet/utils" "github.com/s-rah/go-ricochet/wire/control" "testing" @@ -10,9 +11,10 @@ import ( // Test sending valid packets func TestInit(t *testing.T) { ach := new(AutoConnectionHandler) - privateKey, err := utils.LoadPrivateKeyFromFile("../testing/private_key") - - ach.Init(privateKey, "") + ach.Init() + ach.RegisterChannelHandler("im.ricochet.auth.hidden-service", func() channels.Handler { + return &channels.HiddenServiceAuthChannel{} + }) // Construct the Open Authentication Channel Message messageBuilder := new(utils.MessageBuilder) diff --git a/connection/connection.go b/connection/connection.go index 982a0f9..42af568 100644 --- a/connection/connection.go +++ b/connection/connection.go @@ -106,32 +106,25 @@ func (rc *Connection) Do(do func() error) error { } // RequestOpenChannel sends an OpenChannel message to the remote client. -// and error is returned only if the requirements for opening this channel -// are not met on the local side (a nill error return does not mean the -// channel was opened successfully) -func (rc *Connection) RequestOpenChannel(ctype string, handler Handler) error { +// An error is returned only if the requirements for opening this channel +// are not met on the local side (a nil error return does not mean the +// channel was opened successfully, because channels open asynchronously). +func (rc *Connection) RequestOpenChannel(ctype string, handler channels.Handler) error { rc.traceLog(fmt.Sprintf("requesting open channel of type %s", ctype)) return rc.Do(func() error { - chandler, err := handler.OnOpenChannelRequest(ctype) - - if err != nil { - rc.traceLog(fmt.Sprintf("failed to request open channel of type %v", err)) - return err - } - // Check that we have the authentication already - if chandler.RequiresAuthentication() != "none" { + if handler.RequiresAuthentication() != "none" { // Enforce Authentication Check. - _, authed := rc.Authentication[chandler.RequiresAuthentication()] + _, authed := rc.Authentication[handler.RequiresAuthentication()] if !authed { return utils.UnauthorizedActionError } } - channel, err := rc.channelManager.OpenChannelRequest(chandler) + channel, err := rc.channelManager.OpenChannelRequest(handler) if err != nil { - rc.traceLog(fmt.Sprintf("failed to reqeust open channel of type %v", err)) + rc.traceLog(fmt.Sprintf("failed to request open channel of type %v", err)) return err } @@ -139,18 +132,18 @@ func (rc *Connection) RequestOpenChannel(ctype string, handler Handler) error { rc.SendRicochetPacket(rc.Conn, channel.ID, message) } channel.DelegateAuthorization = func() { - rc.Authentication[chandler.Type()] = true + rc.Authentication[handler.Type()] = true } channel.CloseChannel = func() { rc.SendRicochetPacket(rc.Conn, channel.ID, []byte{}) rc.channelManager.RemoveChannel(channel.ID) } - response, err := chandler.OpenOutbound(channel) + response, err := handler.OpenOutbound(channel) if err == nil { rc.traceLog(fmt.Sprintf("requested open channel of type %s", ctype)) rc.SendRicochetPacket(rc.Conn, 0, response) } else { - rc.traceLog(fmt.Sprintf("failed to reqeust open channel of type %v", err)) + rc.traceLog(fmt.Sprintf("failed to request open channel of type %v", err)) rc.channelManager.RemoveChannel(channel.ID) } return nil diff --git a/connection/inboundconnectionhandler.go b/connection/inboundconnectionhandler.go index 5b6111d..953726b 100644 --- a/connection/inboundconnectionhandler.go +++ b/connection/inboundconnectionhandler.go @@ -39,15 +39,32 @@ func (ich *InboundConnectionHandler) ProcessAuthAsServer(privateKey *rsa.Private return utils.PrivateKeyNotSetError } - ach := new(AutoConnectionHandler) - ach.Init(privateKey, ich.connection.RemoteHostname) - ach.SetServerAuthHandler(sach) + var authAllowed, authKnown bool + var authHostname string - var authResult channels.AuthChannelResult - go func() { - authResult = ach.WaitForAuthenticationEvent() + onAuthValid := func(hostname string, publicKey rsa.PublicKey) (allowed, known bool) { + authAllowed, authKnown = sach(hostname, publicKey) + if authAllowed { + authHostname = hostname + } ich.connection.Break() - }() + return authAllowed, authKnown + } + onAuthInvalid := func(err error) { + // err is ignored at the moment + ich.connection.Break() + } + + ach := new(AutoConnectionHandler) + ach.Init() + ach.RegisterChannelHandler("im.ricochet.auth.hidden-service", + func() channels.Handler { + return &channels.HiddenServiceAuthChannel{ + PrivateKey: privateKey, + ServerAuthValid: onAuthValid, + ServerAuthInvalid: onAuthInvalid, + } + }) policy := policies.UnknownPurposeTimeout err := policy.ExecuteAction(func() error { @@ -55,8 +72,8 @@ func (ich *InboundConnectionHandler) ProcessAuthAsServer(privateKey *rsa.Private }) if err == nil { - if authResult.Accepted == true { - ich.connection.RemoteHostname = authResult.Hostname + if authAllowed == true { + ich.connection.RemoteHostname = authHostname return nil } return utils.ClientFailedToAuthenticateError diff --git a/connection/outboundconnectionhandler.go b/connection/outboundconnectionhandler.go index 999a0d8..af81c24 100644 --- a/connection/outboundconnectionhandler.go +++ b/connection/outboundconnectionhandler.go @@ -38,26 +38,34 @@ func (och *OutboundConnectionHandler) ProcessAuthAsClient(privateKey *rsa.Privat } ach := new(AutoConnectionHandler) - ach.Init(privateKey, och.connection.RemoteHostname) + ach.Init() - var result channels.AuthChannelResult - go func() { - err := och.connection.RequestOpenChannel("im.ricochet.auth.hidden-service", ach) - if err != nil { - return - } - result = ach.WaitForAuthenticationEvent() + var accepted, isKnownContact bool + authCallback := func(accept, known bool) { + accepted = accept + isKnownContact = known + // Cause the Process() call below to return och.connection.Break() - }() + } + + err := och.connection.RequestOpenChannel("im.ricochet.auth.hidden-service", + &channels.HiddenServiceAuthChannel{ + PrivateKey: privateKey, + ServerHostname: och.connection.RemoteHostname, + ClientAuthResult: authCallback, + }) + if err != nil { + return false, err + } policy := policies.UnknownPurposeTimeout - err := policy.ExecuteAction(func() error { + err = policy.ExecuteAction(func() error { return och.connection.Process(ach) }) if err == nil { - if result.Accepted == true { - return result.IsKnownContact, nil + if accepted == true { + return isKnownContact, nil } } return false, utils.ServerRejectedClientConnectionError diff --git a/examples/echobot/main.go b/examples/echobot/main.go index f87080d..485c803 100644 --- a/examples/echobot/main.go +++ b/examples/echobot/main.go @@ -45,7 +45,7 @@ func (echobot *RicochetEchoBot) Connect(privateKeyFile string, hostname string) privateKey, _ := utils.LoadPrivateKeyFromFile(privateKeyFile) echobot.messages = make(chan string) - echobot.Init(privateKey, hostname) + echobot.Init() echobot.RegisterChannelHandler("im.ricochet.contact.request", func() channels.Handler { contact := new(channels.ContactRequestChannel) contact.Handler = echobot @@ -64,13 +64,13 @@ func (echobot *RicochetEchoBot) Connect(privateKeyFile string, hostname string) go rc.Process(echobot) if !known { - err := rc.RequestOpenChannel("im.ricochet.contact.request", echobot) + err := rc.RequestOpenChannel("im.ricochet.contact.request", &channels.ContactRequestChannel{Handler: echobot}) if err != nil { log.Printf("could not contact %s", err) } } - rc.RequestOpenChannel("im.ricochet.chat", echobot) + rc.RequestOpenChannel("im.ricochet.chat", &channels.ChatChannel{Handler: echobot}) for { message := <-echobot.messages log.Printf("Received Message: %s", message)