diff --git a/connection/inboundconnectionhandler.go b/connection/inboundconnectionhandler.go index 953726b..e74bd47 100644 --- a/connection/inboundconnectionhandler.go +++ b/connection/inboundconnectionhandler.go @@ -5,6 +5,7 @@ import ( "github.com/s-rah/go-ricochet/channels" "github.com/s-rah/go-ricochet/policies" "github.com/s-rah/go-ricochet/utils" + "sync" ) // InboundConnectionHandler is a convieniance wrapper for handling inbound @@ -39,6 +40,8 @@ func (ich *InboundConnectionHandler) ProcessAuthAsServer(privateKey *rsa.Private return utils.PrivateKeyNotSetError } + var breakOnce sync.Once + var authAllowed, authKnown bool var authHostname string @@ -47,12 +50,12 @@ func (ich *InboundConnectionHandler) ProcessAuthAsServer(privateKey *rsa.Private if authAllowed { authHostname = hostname } - ich.connection.Break() + breakOnce.Do(func() { go ich.connection.Break() }) return authAllowed, authKnown } onAuthInvalid := func(err error) { // err is ignored at the moment - ich.connection.Break() + breakOnce.Do(func() { go ich.connection.Break() }) } ach := new(AutoConnectionHandler) @@ -66,6 +69,9 @@ func (ich *InboundConnectionHandler) ProcessAuthAsServer(privateKey *rsa.Private } }) + // Ensure that the call to Process() cannot outlive this function, + // particularly for the case where the policy timeout expires + defer breakOnce.Do(func() { ich.connection.Break() }) policy := policies.UnknownPurposeTimeout err := policy.ExecuteAction(func() error { return ich.connection.Process(ach) diff --git a/connection/outboundconnectionhandler.go b/connection/outboundconnectionhandler.go index c966140..be13328 100644 --- a/connection/outboundconnectionhandler.go +++ b/connection/outboundconnectionhandler.go @@ -5,6 +5,7 @@ import ( "github.com/s-rah/go-ricochet/channels" "github.com/s-rah/go-ricochet/policies" "github.com/s-rah/go-ricochet/utils" + "sync" ) // OutboundConnectionHandler is a convieniance wrapper for handling outbound @@ -40,33 +41,50 @@ func (och *OutboundConnectionHandler) ProcessAuthAsClient(privateKey *rsa.Privat ach := new(AutoConnectionHandler) ach.Init() + // Make sure that calls to Break in this function cannot race + var breakOnce sync.Once + var accepted, isKnownContact bool authCallback := func(accept, known bool) { accepted = accept isKnownContact = known - // Cause the Process() call below to return - och.connection.Break() + // Cause the Process() call below to return. + // If Break() is called from here, it _must_ use go, because this will + // execute in the Process goroutine, and Break() will deadlock. + breakOnce.Do(func() { go och.connection.Break() }) } - _, err := och.connection.RequestOpenChannel("im.ricochet.auth.hidden-service", - &channels.HiddenServiceAuthChannel{ - PrivateKey: privateKey, - ServerHostname: och.connection.RemoteHostname, - ClientAuthResult: authCallback, + processResult := make(chan error, 1) + go func() { + // Break Process() if timed out; no-op if Process returned a conn error + defer func() { breakOnce.Do(func() { och.connection.Break() }) }() + policy := policies.UnknownPurposeTimeout + err := policy.ExecuteAction(func() error { + return och.connection.Process(ach) }) + processResult <- err + }() + + err := och.connection.Do(func() error { + _, err := och.connection.RequestOpenChannel("im.ricochet.auth.hidden-service", + &channels.HiddenServiceAuthChannel{ + PrivateKey: privateKey, + ServerHostname: och.connection.RemoteHostname, + ClientAuthResult: authCallback, + }) + return err + }) if err != nil { + breakOnce.Do(func() { och.connection.Break() }) return false, err } - policy := policies.UnknownPurposeTimeout - err = policy.ExecuteAction(func() error { - return och.connection.Process(ach) - }) + if err = <-processResult; err != nil { + return false, err + } - if err == nil { - if accepted == true { - return isKnownContact, nil - } + if accepted == true { + return isKnownContact, nil } return false, utils.ServerRejectedClientConnectionError }